fix(router.py): instrument pre-call-checks for all openai endpoints

Krrish Dholakia 2024-04-18 21:54:25 -07:00
parent 3b9e2a58e2
commit 9c42c847a5
3 changed files with 130 additions and 10 deletions

router.py

@@ -379,6 +379,9 @@ class Router:
else:
model_client = potential_model_client
+### DEPLOYMENT-SPECIFIC PRE-CALL CHECKS ### (e.g. update rpm pre-call. Raise error, if deployment over limit)
+self.routing_strategy_pre_call_checks(deployment=deployment)
response = litellm.completion(
**{
**data,
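With this change, the sync path runs `routing_strategy_pre_call_checks` (defined near the end of this diff) before `litellm.completion` is invoked, so a deployment that is over its rpm limit raises instead of sending the request. A hedged usage sketch; the model name, rpm value, and strategy wiring are illustrative, not taken from this commit:

```python
from litellm import Router

# Illustrative configuration; the model name and rpm value are assumptions.
# 'usage-based-routing-v2' is the strategy the docstrings below tie these
# pre-call checks to.
router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {"model": "gpt-3.5-turbo", "rpm": 60},
        }
    ],
    routing_strategy="usage-based-routing-v2",
)

# If the chosen deployment is already over its rpm limit, the pre-call
# check raises before any request is sent; otherwise the rpm counter is
# incremented and the call proceeds.
response = router.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hello"}],
)
```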
@@ -391,6 +394,7 @@ class Router:
verbose_router_logger.info(
f"litellm.completion(model={model_name})\033[32m 200 OK\033[0m"
)
return response
except Exception as e:
verbose_router_logger.info(
@@ -501,10 +505,12 @@ class Router:
- Check rpm limits before making the call
- If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
"""
-await self.routing_strategy_pre_call_checks(deployment=deployment)
+await self.async_routing_strategy_pre_call_checks(
+    deployment=deployment
+)
response = await _response
else:
-await self.routing_strategy_pre_call_checks(deployment=deployment)
+await self.async_routing_strategy_pre_call_checks(deployment=deployment)
response = await _response
self.success_calls[model_name] += 1
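Two Python details make this hunk work, and the same shape is repeated below for the async image-generation, transcription, text-completion, and embedding endpoints: calling an async litellm function without `await` only creates a coroutine (no request is sent until it is awaited), and `async with rpm_semaphore` serializes the check-then-increment on the shared rpm counter. A minimal self-contained sketch of the pattern, with stand-in names (`pre_call_check` and `request` are toys, not Router methods):

```python
import asyncio
from typing import Awaitable, Callable, Optional

async def guarded_call(
    rpm_semaphore: Optional[asyncio.Semaphore],
    pre_call_check: Callable[[], Awaitable[None]],
    request: Awaitable,
):
    # `request` is a coroutine object: nothing runs until it is awaited,
    # so it is safe to construct it before taking the semaphore.
    if rpm_semaphore is not None and isinstance(rpm_semaphore, asyncio.Semaphore):
        async with rpm_semaphore:
            # Check + increment the rpm counter while holding the
            # semaphore, so no other task can interleave between the
            # read and the write.
            await pre_call_check()
            return await request
    # No rpm semaphore configured for this deployment: check, then call.
    await pre_call_check()
    return await request

async def main():
    sem = asyncio.Semaphore(1)  # acts as an async mutex here

    async def check():
        pass  # toy stand-in for async_routing_strategy_pre_call_checks

    async def fake_request():
        return "200 OK"

    print(await guarded_call(sem, check, fake_request()))

asyncio.run(main())
```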
@@ -579,6 +585,10 @@ class Router:
model_client = potential_model_client
self.total_calls[model_name] += 1
+### DEPLOYMENT-SPECIFIC PRE-CALL CHECKS ### (e.g. update rpm pre-call. Raise error, if deployment over limit)
+self.routing_strategy_pre_call_checks(deployment=deployment)
response = litellm.image_generation(
**{
**data,
@@ -657,7 +667,7 @@ class Router:
model_client = potential_model_client
self.total_calls[model_name] += 1
-response = await litellm.aimage_generation(
+response = litellm.aimage_generation(
**{
**data,
"prompt": prompt,
@@ -666,6 +676,28 @@ class Router:
**kwargs,
}
)
+### CONCURRENCY-SAFE RPM CHECKS ###
+rpm_semaphore = self._get_client(
+    deployment=deployment, kwargs=kwargs, client_type="rpm_client"
+)
+if rpm_semaphore is not None and isinstance(
+    rpm_semaphore, asyncio.Semaphore
+):
+    async with rpm_semaphore:
+        """
+        - Check rpm limits before making the call
+        - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
+        """
+        await self.async_routing_strategy_pre_call_checks(
+            deployment=deployment
+        )
+        response = await response
+else:
+    await self.async_routing_strategy_pre_call_checks(deployment=deployment)
+    response = await response
self.success_calls[model_name] += 1
verbose_router_logger.info(
f"litellm.aimage_generation(model={model_name})\033[32m 200 OK\033[0m"
@@ -757,7 +789,7 @@ class Router:
model_client = potential_model_client
self.total_calls[model_name] += 1
-response = await litellm.atranscription(
+response = litellm.atranscription(
**{
**data,
"file": file,
@@ -766,6 +798,28 @@ class Router:
**kwargs,
}
)
+### CONCURRENCY-SAFE RPM CHECKS ###
+rpm_semaphore = self._get_client(
+    deployment=deployment, kwargs=kwargs, client_type="rpm_client"
+)
+if rpm_semaphore is not None and isinstance(
+    rpm_semaphore, asyncio.Semaphore
+):
+    async with rpm_semaphore:
+        """
+        - Check rpm limits before making the call
+        - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
+        """
+        await self.async_routing_strategy_pre_call_checks(
+            deployment=deployment
+        )
+        response = await response
+else:
+    await self.async_routing_strategy_pre_call_checks(deployment=deployment)
+    response = await response
self.success_calls[model_name] += 1
verbose_router_logger.info(
f"litellm.atranscription(model={model_name})\033[32m 200 OK\033[0m"
@@ -979,7 +1033,8 @@ class Router:
else:
model_client = potential_model_client
self.total_calls[model_name] += 1
-response = await litellm.atext_completion(
+response = litellm.atext_completion(
**{
**data,
"prompt": prompt,
@@ -989,6 +1044,27 @@ class Router:
**kwargs,
}
)
+rpm_semaphore = self._get_client(
+    deployment=deployment, kwargs=kwargs, client_type="rpm_client"
+)
+if rpm_semaphore is not None and isinstance(
+    rpm_semaphore, asyncio.Semaphore
+):
+    async with rpm_semaphore:
+        """
+        - Check rpm limits before making the call
+        - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
+        """
+        await self.async_routing_strategy_pre_call_checks(
+            deployment=deployment
+        )
+        response = await response
+else:
+    await self.async_routing_strategy_pre_call_checks(deployment=deployment)
+    response = await response
self.success_calls[model_name] += 1
verbose_router_logger.info(
f"litellm.atext_completion(model={model_name})\033[32m 200 OK\033[0m"
@@ -1063,6 +1139,10 @@ class Router:
model_client = potential_model_client
self.total_calls[model_name] += 1
+### DEPLOYMENT-SPECIFIC PRE-CALL CHECKS ### (e.g. update rpm pre-call. Raise error, if deployment over limit)
+self.routing_strategy_pre_call_checks(deployment=deployment)
response = litellm.embedding(
**{
**data,
@@ -1147,7 +1227,7 @@ class Router:
model_client = potential_model_client
self.total_calls[model_name] += 1
-response = await litellm.aembedding(
+response = litellm.aembedding(
**{
**data,
"input": input,
@@ -1156,6 +1236,28 @@ class Router:
**kwargs,
}
)
+### CONCURRENCY-SAFE RPM CHECKS ###
+rpm_semaphore = self._get_client(
+    deployment=deployment, kwargs=kwargs, client_type="rpm_client"
+)
+if rpm_semaphore is not None and isinstance(
+    rpm_semaphore, asyncio.Semaphore
+):
+    async with rpm_semaphore:
+        """
+        - Check rpm limits before making the call
+        - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
+        """
+        await self.async_routing_strategy_pre_call_checks(
+            deployment=deployment
+        )
+        response = await response
+else:
+    await self.async_routing_strategy_pre_call_checks(deployment=deployment)
+    response = await response
self.success_calls[model_name] += 1
verbose_router_logger.info(
f"litellm.aembedding(model={model_name})\033[32m 200 OK\033[0m"
@@ -1713,7 +1815,23 @@ class Router:
verbose_router_logger.debug(f"retrieve cooldown models: {cooldown_models}")
return cooldown_models
-async def routing_strategy_pre_call_checks(self, deployment: dict):
+def routing_strategy_pre_call_checks(self, deployment: dict):
+    """
+    Mimics 'async_routing_strategy_pre_call_checks'
+
+    Ensures consistent update rpm implementation for 'usage-based-routing-v2'
+
+    Returns:
+    - None
+
+    Raises:
+    - Rate Limit Exception - If the deployment is over its tpm/rpm limits
+    """
+    for _callback in litellm.callbacks:
+        if isinstance(_callback, CustomLogger):
+            response = _callback.pre_call_check(deployment)
+async def async_routing_strategy_pre_call_checks(self, deployment: dict):
"""
For usage-based-routing-v2, enables running rpm checks before the call is made, inside the semaphore.
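The loop in the new sync `routing_strategy_pre_call_checks` hands the chosen deployment to every `CustomLogger` registered in `litellm.callbacks`. A minimal sketch of a callback that would plug into that loop; the in-process counter here is an illustrative stand-in, not litellm's 'usage-based-routing-v2' limiter:

```python
from collections import defaultdict

from litellm.integrations.custom_logger import CustomLogger

class ToyRpmGuard(CustomLogger):
    """Illustrative pre-call check: raise once a deployment's
    in-process call count reaches its configured rpm."""

    def __init__(self):
        super().__init__()
        self._calls = defaultdict(int)

    def pre_call_check(self, deployment: dict):
        rpm = deployment.get("litellm_params", {}).get("rpm")
        name = deployment.get("model_name", "unknown")
        if rpm is not None and self._calls[name] >= rpm:
            # The real implementation raises a rate-limit error here.
            raise Exception(f"Deployment {name} is over its rpm limit ({rpm})")
        self._calls[name] += 1
```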