From 1a9cf00bb473742f69a85d397e1e610bf17ae95a Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Tue, 19 Nov 2024 12:35:58 -0800 Subject: [PATCH] working test_provider_budgets_e2e_test --- litellm/router.py | 25 ++- litellm/router_strategy/provider_budgets.py | 192 +++++++++++++++++++ litellm/types/router.py | 10 +- tests/local_testing/test_provider_budgets.py | 84 ++++++++ 4 files changed, 303 insertions(+), 8 deletions(-) create mode 100644 litellm/router_strategy/provider_budgets.py create mode 100644 tests/local_testing/test_provider_budgets.py diff --git a/litellm/router.py b/litellm/router.py index 97065bc85..d582f614f 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -59,6 +59,7 @@ from litellm.router_strategy.lowest_cost import LowestCostLoggingHandler from litellm.router_strategy.lowest_latency import LowestLatencyLoggingHandler from litellm.router_strategy.lowest_tpm_rpm import LowestTPMLoggingHandler from litellm.router_strategy.lowest_tpm_rpm_v2 import LowestTPMLoggingHandler_v2 +from litellm.router_strategy.provider_budgets import ProviderBudgetLimiting from litellm.router_strategy.simple_shuffle import simple_shuffle from litellm.router_strategy.tag_based_routing import get_deployments_for_tag from litellm.router_utils.batch_utils import ( @@ -234,8 +235,9 @@ class Router: "latency-based-routing", "cost-based-routing", "usage-based-routing-v2", + "provider-budget-routing", ] = "simple-shuffle", - routing_strategy_args: dict = {}, # just for latency-based routing + routing_strategy_args: dict = {}, # just for latency-based, semaphore: Optional[asyncio.Semaphore] = None, alerting_config: Optional[AlertingConfig] = None, router_general_settings: Optional[ @@ -644,6 +646,16 @@ class Router: ) if isinstance(litellm.callbacks, list): litellm.callbacks.append(self.lowestcost_logger) # type: ignore + elif ( + routing_strategy == RoutingStrategy.PROVIDER_BUDGET_LIMITING.value + or routing_strategy == RoutingStrategy.PROVIDER_BUDGET_LIMITING + ): + self.provider_budget_logger = ProviderBudgetLimiting( + router_cache=self.cache, + provider_budget_config=routing_strategy_args, + ) + if isinstance(litellm.callbacks, list): + litellm.callbacks.append(self.provider_budget_logger) # type: ignore else: pass @@ -5055,6 +5067,7 @@ class Router: and self.routing_strategy != "cost-based-routing" and self.routing_strategy != "latency-based-routing" and self.routing_strategy != "least-busy" + and self.routing_strategy != "provider-budget-routing" ): # prevent regressions for other routing strategies, that don't have async get available deployments implemented. return self.get_available_deployment( model=model, @@ -5170,6 +5183,16 @@ class Router: healthy_deployments=healthy_deployments, # type: ignore ) ) + elif ( + self.routing_strategy == "provider-budget-routing" + and self.provider_budget_logger is not None + ): + deployment = ( + await self.provider_budget_logger.async_get_available_deployments( + model_group=model, + healthy_deployments=healthy_deployments, # type: ignore + ) + ) else: deployment = None if deployment is None: diff --git a/litellm/router_strategy/provider_budgets.py b/litellm/router_strategy/provider_budgets.py new file mode 100644 index 000000000..423bcdd59 --- /dev/null +++ b/litellm/router_strategy/provider_budgets.py @@ -0,0 +1,192 @@ +import random +from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypedDict, Union + +import litellm +from litellm._logging import verbose_router_logger +from litellm.caching.caching import DualCache +from litellm.integrations.custom_logger import CustomLogger +from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs +from litellm.types.router import ( + LiteLLM_Params, + ProviderBudgetConfigType, + ProviderBudgetInfo, +) +from litellm.types.utils import StandardLoggingPayload + +if TYPE_CHECKING: + from opentelemetry.trace import Span as _Span + + Span = _Span +else: + Span = Any + + +class ProviderSpend(TypedDict, total=False): + """ + Provider spend data + + { + "openai": 300.0, + "anthropic": 100.0 + } + """ + + provider: str + spend: float + + +class ProviderBudgetLimiting(CustomLogger): + def __init__(self, router_cache: DualCache, provider_budget_config: dict): + self.router_cache = router_cache + self.provider_budget_config: ProviderBudgetConfigType = provider_budget_config + verbose_router_logger.debug( + f"Initalized Provider budget config: {self.provider_budget_config}" + ) + + async def async_get_available_deployments( + self, + model_group: str, + healthy_deployments: List[Dict], + messages: Optional[List[Dict[str, str]]] = None, + input: Optional[Union[str, List]] = None, + request_kwargs: Optional[Dict] = None, + ): + """ + Filter list of healthy deployments based on provider budget + """ + potential_deployments: List[Dict] = [] + + parent_otel_span: Optional[Span] = _get_parent_otel_span_from_kwargs( + request_kwargs + ) + + for deployment in healthy_deployments: + provider = self._get_llm_provider_for_deployment(deployment) + budget_config = self._get_budget_config_for_provider(provider) + if budget_config is None: + verbose_router_logger.debug( + f"No budget config found for provider {provider}, skipping" + ) + continue + + budget_limit = budget_config.budget_limit + current_spend: float = ( + await self.router_cache.async_get_cache( + key=f"provider_spend:{provider}:{budget_config.time_period}", + parent_otel_span=parent_otel_span, + ) + or 0.0 + ) + + verbose_router_logger.debug( + f"Current spend for {provider}: {current_spend}, budget limit: {budget_limit}" + ) + + if current_spend >= budget_limit: + verbose_router_logger.debug( + f"Skipping deployment {deployment} for provider {provider} as spend limit exceeded" + ) + continue + + potential_deployments.append(deployment) + # randomly pick one deployment from potential_deployments + if potential_deployments: + return random.choice(potential_deployments) + return None + + async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): + """ + Increment provider spend in DualCache (InMemory + Redis) + """ + verbose_router_logger.debug( + f"in ProviderBudgetLimiting.async_log_success_event" + ) + standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get( + "standard_logging_object", None + ) + if standard_logging_payload is None: + raise ValueError("standard_logging_payload is required") + + response_cost: float = standard_logging_payload.get("response_cost", 0) + + custom_llm_provider: str = kwargs.get("litellm_params", {}).get( + "custom_llm_provider", None + ) + if custom_llm_provider is None: + raise ValueError("custom_llm_provider is required") + + budget_config = self._get_budget_config_for_provider(custom_llm_provider) + if budget_config is None: + raise ValueError( + f"No budget config found for provider {custom_llm_provider}, self.provider_budget_config: {self.provider_budget_config}" + ) + + spend_key = f"provider_spend:{custom_llm_provider}:{budget_config.time_period}" + ttl_seconds = self.get_ttl_seconds(budget_config.time_period) + verbose_router_logger.debug( + f"Incrementing spend for {spend_key} by {response_cost}, ttl: {ttl_seconds}" + ) + # Increment the spend in Redis and set TTL + await self.router_cache.async_increment_cache( + key=spend_key, + value=response_cost, + ttl=ttl_seconds, + ) + verbose_router_logger.debug( + f"Incremented spend for {spend_key} by {response_cost}, ttl: {ttl_seconds}" + ) + + def _get_budget_config_for_provider( + self, provider: str + ) -> Optional[ProviderBudgetInfo]: + return self.provider_budget_config.get(provider, None) + + def _get_llm_provider_for_deployment(self, deployment: Dict) -> str: + try: + _litellm_params: LiteLLM_Params = LiteLLM_Params( + **deployment["litellm_params"] + ) + _, custom_llm_provider, _, _ = litellm.get_llm_provider( + model=_litellm_params.model, + litellm_params=_litellm_params, + ) + except Exception as e: + raise e + return custom_llm_provider + + def _get_unique_custom_llm_providers_in_deployments( + self, deployments: List[Dict] + ) -> list: + """ + Get unique custom LLM providers in deployments + """ + unique_providers = set() + for deployment in deployments: + provider = self._get_llm_provider_for_deployment(deployment) + unique_providers.add(provider) + return list(unique_providers) + + def get_ttl_seconds(self, time_period: str) -> int: + """ + Convert time period (e.g., '1d', '30d') to seconds for Redis TTL + """ + if time_period.endswith("d"): + days = int(time_period[:-1]) + return days * 24 * 60 * 60 + raise ValueError(f"Unsupported time period format: {time_period}") + + def get_budget_limit(self, custom_llm_provider: str, time_period: str) -> float: + """ + Fetch the budget limit for a given provider and time period. + This can be fetched from a config or database. + """ + _provider_budget_settings = self.provider_budget_config.get( + custom_llm_provider, None + ) + if _provider_budget_settings is None: + return float("inf") + + verbose_router_logger.debug( + f"Provider budget settings: {_provider_budget_settings}" + ) + return _provider_budget_settings.budget_limit diff --git a/litellm/types/router.py b/litellm/types/router.py index c160a8124..f4d2b39ed 100644 --- a/litellm/types/router.py +++ b/litellm/types/router.py @@ -631,13 +631,9 @@ class RoutingStrategy(enum.Enum): PROVIDER_BUDGET_LIMITING = "provider-budget-routing" -class DayToBudgetLimit(TypedDict): - day: str +class ProviderBudgetInfo(BaseModel): + time_period: str # e.g., '1d', '30d' budget_limit: float -class ProviderBudgetConfig(TypedDict): - custom_llm_provider: str # The name of the provider (e.g., OpenAI, Azure) - budgets: ( - DayToBudgetLimit # Time periods (e.g., '1d', '30d') mapped to budget limits - ) +ProviderBudgetConfigType = Dict[str, ProviderBudgetInfo] diff --git a/tests/local_testing/test_provider_budgets.py b/tests/local_testing/test_provider_budgets.py new file mode 100644 index 000000000..40630c130 --- /dev/null +++ b/tests/local_testing/test_provider_budgets.py @@ -0,0 +1,84 @@ +import sys, os, asyncio, time, random +from datetime import datetime +import traceback +from dotenv import load_dotenv + +load_dotenv() +import os, copy + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path +import pytest +from litellm import Router +from litellm.router_strategy.provider_budgets import ProviderBudgetLimiting +from litellm.types.router import ( + RoutingStrategy, + ProviderBudgetConfigType, + ProviderBudgetInfo, +) +from litellm.caching.caching import DualCache +import logging +from litellm._logging import verbose_router_logger + +verbose_router_logger.setLevel(logging.DEBUG) + + +@pytest.mark.asyncio +async def test_provider_budgets_e2e_test(): + """ + Expected behavior: + - First request forced to OpenAI + - Hit OpenAI budget limit + - Next 3 requests all go to Azure + + """ + provider_budget_config: ProviderBudgetConfigType = { + "openai": ProviderBudgetInfo(time_period="1d", budget_limit=0.000000000001), + "azure": ProviderBudgetInfo(time_period="1d", budget_limit=100), + } + + router = Router( + routing_strategy="provider-budget-routing", + routing_strategy_args=provider_budget_config, + model_list=[ + { + "model_name": "gpt-3.5-turbo", # openai model name + "litellm_params": { # params for litellm completion/embedding call + "model": "azure/chatgpt-v-2", + "api_key": os.getenv("AZURE_API_KEY"), + "api_version": os.getenv("AZURE_API_VERSION"), + "api_base": os.getenv("AZURE_API_BASE"), + }, + "model_info": {"id": "azure-model-id"}, + }, + { + "model_name": "gpt-3.5-turbo", # openai model name + "litellm_params": { + "model": "openai/gpt-4o-mini", + }, + "model_info": {"id": "openai-model-id"}, + }, + ], + ) + + response = await router.acompletion( + messages=[{"role": "user", "content": "Hello, how are you?"}], + model="openai/gpt-4o-mini", + ) + print(response) + + await asyncio.sleep(0.5) + + for _ in range(3): + response = await router.acompletion( + messages=[{"role": "user", "content": "Hello, how are you?"}], + model="gpt-3.5-turbo", + ) + print(response) + + print("response.hidden_params", response._hidden_params) + + await asyncio.sleep(0.5) + + assert response._hidden_params.get("custom_llm_provider") == "azure"