forked from phoenix/litellm-mirror
use async_filter_deployments
commit 95f21722a0 (parent 50168889be)
2 changed files with 24 additions and 8 deletions
@@ -522,6 +522,11 @@ class Router:
         self.service_logger_obj = ServiceLogging()
         self.routing_strategy_args = routing_strategy_args
         self.provider_budget_config = provider_budget_config
+        if self.provider_budget_config is not None:
+            self.provider_budget_logger = ProviderBudgetLimiting(
+                router_cache=self.cache,
+                provider_budget_config=self.provider_budget_config,
+            )
         self.retry_policy: Optional[RetryPolicy] = None
         if retry_policy is not None:
             if isinstance(retry_policy, dict):
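
This hunk only adds constructor wiring: when a provider_budget_config is passed, the Router builds a ProviderBudgetLimiting instance backed by the router cache. A minimal sketch of what such a config might look like on the caller's side is below; the key names and units are assumptions for illustration and do not appear in this diff.

# Hypothetical per-provider budget config. The keys ("budget_limit", "time_period")
# and units are assumptions; the hunk above only shows that a non-None config
# triggers ProviderBudgetLimiting(router_cache=self.cache, provider_budget_config=...).
provider_budget_config = {
    "openai": {"budget_limit": 100.0, "time_period": "1d"},  # e.g. 100 USD per day
    "azure": {"budget_limit": 50.0, "time_period": "1d"},
}
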
@@ -5114,6 +5119,14 @@ class Router:
             healthy_deployments=healthy_deployments,
         )
 
+        if self.provider_budget_config is not None:
+            healthy_deployments = (
+                await self.provider_budget_logger.async_filter_deployments(
+                    healthy_deployments=healthy_deployments,
+                    request_kwargs=request_kwargs,
+                )
+            )
+
         if len(healthy_deployments) == 0:
             exception = await async_raise_no_deployment_exception(
                 litellm_router_instance=self,
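
With this change the candidate deployments are passed through the provider budget limiter before the empty-list check, so a provider that has exhausted its budget falls into the existing "no deployment available" error path. A self-contained sketch of that call pattern follows; StubProviderBudgetLimiting is a stand-in for the real class, and its filtering rule is invented purely for illustration.

# Sketch of the call pattern added above, with a stub limiter standing in for
# ProviderBudgetLimiting. Names mirror the diff; the stub's logic is illustrative only.
import asyncio
from typing import Dict, List, Optional


class StubProviderBudgetLimiting:
    async def async_filter_deployments(
        self,
        healthy_deployments: List[Dict],
        request_kwargs: Optional[Dict] = None,
    ) -> List[Dict]:
        # Pretend every provider except "azure" is over budget.
        return [
            d for d in healthy_deployments
            if d["litellm_params"]["model"].startswith("azure/")
        ]


async def pick_deployments(healthy_deployments: List[Dict]) -> List[Dict]:
    provider_budget_logger = StubProviderBudgetLimiting()
    healthy_deployments = await provider_budget_logger.async_filter_deployments(
        healthy_deployments=healthy_deployments,
        request_kwargs=None,
    )
    if len(healthy_deployments) == 0:
        # At this point the router raises via async_raise_no_deployment_exception(...).
        raise RuntimeError("No deployments available within provider budgets")
    return healthy_deployments


print(asyncio.run(pick_deployments([
    {"litellm_params": {"model": "openai/gpt-3.5-turbo"}},
    {"litellm_params": {"model": "azure/gpt-4o-mini"}},
])))
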
@@ -55,19 +55,23 @@ class ProviderBudgetLimiting(CustomLogger):
 
     async def async_filter_deployments(
         self,
-        healthy_deployments: List[Dict],
+        healthy_deployments: Union[List[Dict[str, Any]], Dict[str, Any]],
         request_kwargs: Optional[Dict] = None,
-    ) -> Optional[Dict]:
+    ):
         """
-        For all deployments, check their LLM provider budget is less than their budget limit.
+        Filter out deployments that have exceeded their provider budget limit.
+
 
-        If multiple deployments are available, randomly pick one.
-
         Example:
         if deployment = openai/gpt-3.5-turbo
-            check if openai budget limit is exceeded
+            and openai spend > openai budget limit
+            then skip this deployment
         """
+
+        # If a single deployment is passed, convert it to a list
+        if isinstance(healthy_deployments, dict):
+            healthy_deployments = [healthy_deployments]
 
         potential_deployments: List[Dict] = []
 
         # Extract the parent OpenTelemetry span for tracing
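
The reworked signature accepts either a single deployment dict or a list and normalizes to a list before filtering. The rule in the docstring (skip a deployment whose provider has spent more than its budget limit) can be sketched roughly as follows; the in-memory spend and limit tables are stand-ins, since the real class tracks per-provider spend in the router cache.

# Rough sketch of the filtering rule described in the docstring above. The
# spend/limit dicts are stand-ins; the actual implementation reads spend from
# the router cache per provider and time period.
from typing import Any, Dict, List, Union


def provider_of(deployment: Dict[str, Any]) -> str:
    # Simplified provider extraction: "openai/gpt-3.5-turbo" -> "openai"
    return deployment["litellm_params"]["model"].split("/")[0]


def filter_by_provider_budget(
    healthy_deployments: Union[List[Dict[str, Any]], Dict[str, Any]],
    spend_by_provider: Dict[str, float],
    budget_by_provider: Dict[str, float],
) -> List[Dict[str, Any]]:
    # If a single deployment is passed, convert it to a list (as in the diff)
    if isinstance(healthy_deployments, dict):
        healthy_deployments = [healthy_deployments]

    potential_deployments: List[Dict[str, Any]] = []
    for deployment in healthy_deployments:
        provider = provider_of(deployment)
        limit = budget_by_provider.get(provider)
        if limit is not None and spend_by_provider.get(provider, 0.0) > limit:
            continue  # provider spend exceeds its budget limit -> skip this deployment
        potential_deployments.append(deployment)
    return potential_deployments

In this sketch, deployments whose provider has no configured budget pass through unchanged; whether the real implementation behaves the same way is not visible in this hunk.
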
@@ -134,8 +138,7 @@ class ProviderBudgetLimiting(CustomLogger):
 
             potential_deployments.append(deployment)
 
-        # Randomly pick one deployment from potential deployments
-        return random.choice(potential_deployments) if potential_deployments else None
+        return potential_deployments
 
     async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
         """
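
The behavioral change in this last hunk: instead of picking one deployment at random (or returning None), the method now returns the entire filtered list. Since the router call site in the second hunk only filters and checks for emptiness, the final selection stays with the router's existing routing logic. A small contrast of the two return shapes, using stand-in data:

# Contrast of the old and new return shapes, using stand-in data.
import random
from typing import Dict, List, Optional

potential_deployments: List[Dict] = [
    {"model_name": "gpt-3.5-turbo"},
    {"model_name": "gpt-4o"},
]

# Old behaviour (removed above): pick one deployment here, or None if the list is empty.
old_result: Optional[Dict] = (
    random.choice(potential_deployments) if potential_deployments else None
)

# New behaviour: return the whole filtered list; the caller decides which one to use.
new_result: List[Dict] = potential_deployments
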