Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-26 03:04:13 +00:00
fix(router.py): instrument pre-call-checks for all openai endpoints
commit 9c42c847a5 (parent 3b9e2a58e2)
3 changed files with 130 additions and 10 deletions
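This change instruments deployment-specific pre-call checks across the router's OpenAI endpoints. The synchronous paths (`completion`, `image_generation`, `embedding`) gain a `routing_strategy_pre_call_checks` call before dispatch; the async paths (`aimage_generation`, `atranscription`, `atext_completion`, `aembedding`) now build the litellm coroutine without awaiting it, run `async_routing_strategy_pre_call_checks` (inside an rpm semaphore when one is configured for the deployment), and only then await the response. The previously async `routing_strategy_pre_call_checks` is renamed to `async_routing_strategy_pre_call_checks`, and a new synchronous method of the old name dispatches to `pre_call_check` on registered `CustomLogger` callbacks.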
```diff
@@ -379,6 +379,9 @@ class Router:
             else:
                 model_client = potential_model_client
 
+            ### DEPLOYMENT-SPECIFIC PRE-CALL CHECKS ### (e.g. update rpm pre-call. Raise error, if deployment over limit)
+            self.routing_strategy_pre_call_checks(deployment=deployment)
+
             response = litellm.completion(
                 **{
                     **data,
```
```diff
@@ -391,6 +394,7 @@ class Router:
             verbose_router_logger.info(
                 f"litellm.completion(model={model_name})\033[32m 200 OK\033[0m"
             )
+
             return response
         except Exception as e:
             verbose_router_logger.info(
```
```diff
@@ -501,10 +505,12 @@ class Router:
                     - Check rpm limits before making the call
                     - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
                     """
-                    await self.routing_strategy_pre_call_checks(deployment=deployment)
+                    await self.async_routing_strategy_pre_call_checks(
+                        deployment=deployment
+                    )
                     response = await _response
             else:
-                await self.routing_strategy_pre_call_checks(deployment=deployment)
+                await self.async_routing_strategy_pre_call_checks(deployment=deployment)
                 response = await _response
 
             self.success_calls[model_name] += 1
```
```diff
@@ -579,6 +585,10 @@ class Router:
             model_client = potential_model_client
 
             self.total_calls[model_name] += 1
+
+            ### DEPLOYMENT-SPECIFIC PRE-CALL CHECKS ### (e.g. update rpm pre-call. Raise error, if deployment over limit)
+            self.routing_strategy_pre_call_checks(deployment=deployment)
+
             response = litellm.image_generation(
                 **{
                     **data,
```
```diff
@@ -657,7 +667,7 @@ class Router:
                 model_client = potential_model_client
 
             self.total_calls[model_name] += 1
-            response = await litellm.aimage_generation(
+            response = litellm.aimage_generation(
                 **{
                     **data,
                     "prompt": prompt,
```
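Note the call-site change above: with the `await` dropped, `response` first holds an un-awaited coroutine. The block added in the next hunk runs the rpm checks (inside the semaphore when one exists for the deployment) and only then awaits it, so a deployment that is over its limit fails before the request is ever sent.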
```diff
@@ -666,6 +676,28 @@ class Router:
                     **kwargs,
                 }
             )
+
+            ### CONCURRENCY-SAFE RPM CHECKS ###
+            rpm_semaphore = self._get_client(
+                deployment=deployment, kwargs=kwargs, client_type="rpm_client"
+            )
+
+            if rpm_semaphore is not None and isinstance(
+                rpm_semaphore, asyncio.Semaphore
+            ):
+                async with rpm_semaphore:
+                    """
+                    - Check rpm limits before making the call
+                    - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
+                    """
+                    await self.async_routing_strategy_pre_call_checks(
+                        deployment=deployment
+                    )
+                    response = await response
+            else:
+                await self.async_routing_strategy_pre_call_checks(deployment=deployment)
+                response = await response
+
             self.success_calls[model_name] += 1
             verbose_router_logger.info(
                 f"litellm.aimage_generation(model={model_name})\033[32m 200 OK\033[0m"
```
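The same concurrency-safe block is added to each async endpoint. A minimal runnable sketch of the pattern, assuming a hypothetical `pre_call_checks` coroutine and a `fake_llm_call` stand-in (neither is litellm code):

```python
import asyncio
from typing import Optional

# Minimal sketch of the concurrency-safe rpm-check pattern from the diff.
# `fake_llm_call` and `pre_call_checks` are hypothetical stand-ins for
# litellm.aimage_generation(...) and async_routing_strategy_pre_call_checks.

async def fake_llm_call(prompt: str) -> str:
    await asyncio.sleep(0.01)  # stand-in for the network round trip
    return f"response to {prompt!r}"

async def pre_call_checks(state: dict, rpm_limit: int) -> None:
    # Read-check-increment runs inside the semaphore, so concurrent
    # callers cannot both slip past the limit.
    if state["rpm"] >= rpm_limit:
        raise RuntimeError("deployment over rpm limit")
    state["rpm"] += 1

async def guarded_call(prompt: str, state: dict, sem: Optional[asyncio.Semaphore]) -> str:
    # Build the coroutine WITHOUT awaiting it, mirroring the diff's
    # `response = litellm.aimage_generation(...)` change.
    response = fake_llm_call(prompt)
    try:
        if sem is not None:
            async with sem:
                await pre_call_checks(state, rpm_limit=2)
                return await response  # await only after the checks pass
        await pre_call_checks(state, rpm_limit=2)
        return await response
    except Exception:
        response.close()  # tidy up the never-awaited coroutine
        raise

async def main() -> None:
    state = {"rpm": 0}
    sem = asyncio.Semaphore(1)
    results = await asyncio.gather(
        *(guarded_call(f"p{i}", state, sem) for i in range(3)),
        return_exceptions=True,
    )
    print(results)  # the third call fails: the rpm limit of 2 is exhausted

asyncio.run(main())
```

Because the semaphore serializes the check-and-increment with the await that follows, the rpm counter can be updated globally without a race between concurrent callers.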
```diff
@@ -757,7 +789,7 @@ class Router:
                 model_client = potential_model_client
 
             self.total_calls[model_name] += 1
-            response = await litellm.atranscription(
+            response = litellm.atranscription(
                 **{
                     **data,
                     "file": file,
```
```diff
@@ -766,6 +798,28 @@ class Router:
                     **kwargs,
                 }
             )
+
+            ### CONCURRENCY-SAFE RPM CHECKS ###
+            rpm_semaphore = self._get_client(
+                deployment=deployment, kwargs=kwargs, client_type="rpm_client"
+            )
+
+            if rpm_semaphore is not None and isinstance(
+                rpm_semaphore, asyncio.Semaphore
+            ):
+                async with rpm_semaphore:
+                    """
+                    - Check rpm limits before making the call
+                    - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
+                    """
+                    await self.async_routing_strategy_pre_call_checks(
+                        deployment=deployment
+                    )
+                    response = await response
+            else:
+                await self.async_routing_strategy_pre_call_checks(deployment=deployment)
+                response = await response
+
             self.success_calls[model_name] += 1
             verbose_router_logger.info(
                 f"litellm.atranscription(model={model_name})\033[32m 200 OK\033[0m"
```
```diff
@@ -979,7 +1033,8 @@ class Router:
             else:
                 model_client = potential_model_client
             self.total_calls[model_name] += 1
-            response = await litellm.atext_completion(
+
+            response = litellm.atext_completion(
                 **{
                     **data,
                     "prompt": prompt,
```
```diff
@@ -989,6 +1044,27 @@ class Router:
                     **kwargs,
                 }
             )
+
+            rpm_semaphore = self._get_client(
+                deployment=deployment, kwargs=kwargs, client_type="rpm_client"
+            )
+
+            if rpm_semaphore is not None and isinstance(
+                rpm_semaphore, asyncio.Semaphore
+            ):
+                async with rpm_semaphore:
+                    """
+                    - Check rpm limits before making the call
+                    - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
+                    """
+                    await self.async_routing_strategy_pre_call_checks(
+                        deployment=deployment
+                    )
+                    response = await response
+            else:
+                await self.async_routing_strategy_pre_call_checks(deployment=deployment)
+                response = await response
+
             self.success_calls[model_name] += 1
             verbose_router_logger.info(
                 f"litellm.atext_completion(model={model_name})\033[32m 200 OK\033[0m"
```
```diff
@@ -1063,6 +1139,10 @@ class Router:
             model_client = potential_model_client
 
             self.total_calls[model_name] += 1
+
+            ### DEPLOYMENT-SPECIFIC PRE-CALL CHECKS ### (e.g. update rpm pre-call. Raise error, if deployment over limit)
+            self.routing_strategy_pre_call_checks(deployment=deployment)
+
             response = litellm.embedding(
                 **{
                     **data,
```
```diff
@@ -1147,7 +1227,7 @@ class Router:
                 model_client = potential_model_client
 
             self.total_calls[model_name] += 1
-            response = await litellm.aembedding(
+            response = litellm.aembedding(
                 **{
                     **data,
                     "input": input,
```
```diff
@@ -1156,6 +1236,28 @@ class Router:
                     **kwargs,
                 }
             )
+
+            ### CONCURRENCY-SAFE RPM CHECKS ###
+            rpm_semaphore = self._get_client(
+                deployment=deployment, kwargs=kwargs, client_type="rpm_client"
+            )
+
+            if rpm_semaphore is not None and isinstance(
+                rpm_semaphore, asyncio.Semaphore
+            ):
+                async with rpm_semaphore:
+                    """
+                    - Check rpm limits before making the call
+                    - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
+                    """
+                    await self.async_routing_strategy_pre_call_checks(
+                        deployment=deployment
+                    )
+                    response = await response
+            else:
+                await self.async_routing_strategy_pre_call_checks(deployment=deployment)
+                response = await response
+
             self.success_calls[model_name] += 1
             verbose_router_logger.info(
                 f"litellm.aembedding(model={model_name})\033[32m 200 OK\033[0m"
```
```diff
@@ -1713,7 +1815,23 @@ class Router:
         verbose_router_logger.debug(f"retrieve cooldown models: {cooldown_models}")
         return cooldown_models
 
-    async def routing_strategy_pre_call_checks(self, deployment: dict):
+    def routing_strategy_pre_call_checks(self, deployment: dict):
+        """
+        Mimics 'async_routing_strategy_pre_call_checks'
+
+        Ensures consistent update rpm implementation for 'usage-based-routing-v2'
+
+        Returns:
+        - None
+
+        Raises:
+        - Rate Limit Exception - If the deployment is over it's tpm/rpm limits
+        """
+        for _callback in litellm.callbacks:
+            if isinstance(_callback, CustomLogger):
+                response = _callback.pre_call_check(deployment)
+
+    async def async_routing_strategy_pre_call_checks(self, deployment: dict):
         """
         For usage-based-routing-v2, enables running rpm checks before the call is made, inside the semaphore.
 
```
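The new synchronous method simply fans out to `pre_call_check` on every `CustomLogger` in `litellm.callbacks`. A hedged sketch of what such a callback might look like; `RPMLimitLogger`, its in-memory counter, and the exact deployment-dict field locations are assumptions for illustration, not litellm's implementation:

```python
import litellm
from litellm.integrations.custom_logger import CustomLogger

class RPMLimitLogger(CustomLogger):
    """Illustrative callback; not part of litellm."""

    def __init__(self) -> None:
        # A real implementation would reset this per minute (or back it
        # with a shared cache); kept in-memory to keep the sketch small.
        self.rpm_counts: dict = {}

    def pre_call_check(self, deployment: dict) -> None:
        # Field locations are assumed: deployment id under model_info,
        # rpm limit under litellm_params.
        model_id = deployment.get("model_info", {}).get("id")
        rpm_limit = deployment.get("litellm_params", {}).get("rpm")
        if model_id is None or rpm_limit is None:
            return  # nothing to enforce
        if self.rpm_counts.get(model_id, 0) >= rpm_limit:
            # Raising here is what makes the router fail *before* the call.
            raise Exception(f"deployment {model_id} over rpm limit {rpm_limit}")
        self.rpm_counts[model_id] = self.rpm_counts.get(model_id, 0) + 1

# Registering it makes routing_strategy_pre_call_checks pick it up:
litellm.callbacks = [RPMLimitLogger()]
```

On the async paths the router invokes this inside the rpm semaphore, so the check-and-increment above is effectively atomic per deployment.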