diff --git a/docs/my-website/docs/proxy/config_settings.md b/docs/my-website/docs/proxy/config_settings.md index 17fe69bad2..5559592407 100644 --- a/docs/my-website/docs/proxy/config_settings.md +++ b/docs/my-website/docs/proxy/config_settings.md @@ -285,6 +285,8 @@ router_settings: | redis_url | str | URL for Redis server. **Known performance issue with Redis URL.** | | cache_responses | boolean | Flag to enable caching LLM Responses, if cache set under `router_settings`. If true, caches responses. Defaults to False. | | router_general_settings | RouterGeneralSettings | [SDK-Only] Router general settings - contains optimizations like 'async_only_mode'. [Docs](../routing.md#router-general-settings) | +| optional_pre_call_checks | List[str] | List of pre-call checks to add to the router. Currently supported: 'router_budget_limiting', 'prompt_caching' | + ### environment variables - Reference diff --git a/litellm/__init__.py b/litellm/__init__.py index c1b0d86d12..576239b946 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -34,6 +34,7 @@ from litellm.proxy._types import ( LiteLLM_UpperboundKeyGenerateParams, ) from litellm.types.utils import StandardKeyGenerationConfig, LlmProviders +from litellm.integrations.custom_logger import CustomLogger import httpx import dotenv from enum import Enum @@ -75,7 +76,9 @@ logged_real_time_event_types: Optional[Union[List[str], Literal["*"]]] = None _known_custom_logger_compatible_callbacks: List = list( get_args(_custom_logger_compatible_callbacks_literal) ) -callbacks: List[Union[Callable, _custom_logger_compatible_callbacks_literal]] = [] +callbacks: List[ + Union[Callable, _custom_logger_compatible_callbacks_literal, CustomLogger] +] = [] langfuse_default_tags: Optional[List[str]] = None langsmith_batch_size: Optional[int] = None argilla_batch_size: Optional[int] = None diff --git a/litellm/integrations/custom_logger.py b/litellm/integrations/custom_logger.py index d62bd3e4db..b714936920 100644 --- a/litellm/integrations/custom_logger.py +++ b/litellm/integrations/custom_logger.py @@ -3,7 +3,7 @@ import os import traceback from datetime import datetime as datetimeObj -from typing import TYPE_CHECKING, Any, Literal, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, List, Literal, Optional, Tuple, Union import dotenv from pydantic import BaseModel @@ -11,7 +11,7 @@ from pydantic import BaseModel from litellm.caching.caching import DualCache from litellm.proxy._types import UserAPIKeyAuth from litellm.types.integrations.argilla import ArgillaItem -from litellm.types.llms.openai import ChatCompletionRequest +from litellm.types.llms.openai import AllMessageValues, ChatCompletionRequest from litellm.types.services import ServiceLoggerPayload from litellm.types.utils import ( AdapterCompletionStreamWrapper, @@ -69,6 +69,16 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac Allows usage-based-routing-v2 to run pre-call rpm checks within the picked deployment's semaphore (concurrency-safe tpm/rpm checks). 
""" + async def async_filter_deployments( + self, + model: str, + healthy_deployments: List, + messages: Optional[List[AllMessageValues]], + request_kwargs: Optional[dict] = None, + parent_otel_span: Optional[Span] = None, + ) -> List[dict]: + return healthy_deployments + async def async_pre_call_check( self, deployment: dict, parent_otel_span: Optional[Span] ) -> Optional[dict]: diff --git a/litellm/router.py b/litellm/router.py index 368be65775..9e1fb7d9f4 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -92,6 +92,9 @@ from litellm.router_utils.handle_error import ( async_raise_no_deployment_exception, send_llm_exception_alert, ) +from litellm.router_utils.pre_call_checks.prompt_caching_deployment_check import ( + PromptCachingDeploymentCheck, +) from litellm.router_utils.router_callbacks.track_deployment_metrics import ( increment_deployment_failures_for_current_minute, increment_deployment_successes_for_current_minute, @@ -128,6 +131,7 @@ from litellm.types.router import ( LiteLLMParamsTypedDict, ModelGroupInfo, ModelInfo, + OptionalPreCallChecks, RetryPolicy, RouterCacheEnum, RouterErrors, @@ -153,6 +157,7 @@ from litellm.utils import ( get_llm_provider, get_secret, get_utc_datetime, + is_prompt_caching_valid_prompt, is_region_allowed, ) @@ -248,6 +253,7 @@ class Router: "cost-based-routing", "usage-based-routing-v2", ] = "simple-shuffle", + optional_pre_call_checks: Optional[OptionalPreCallChecks] = None, routing_strategy_args: dict = {}, # just for latency-based provider_budget_config: Optional[GenericBudgetConfigType] = None, alerting_config: Optional[AlertingConfig] = None, @@ -542,11 +548,10 @@ class Router: if RouterBudgetLimiting.should_init_router_budget_limiter( model_list=model_list, provider_budget_config=self.provider_budget_config ): - self.router_budget_logger = RouterBudgetLimiting( - router_cache=self.cache, - provider_budget_config=self.provider_budget_config, - model_list=self.model_list, - ) + if optional_pre_call_checks is not None: + optional_pre_call_checks.append("router_budget_limiting") + else: + optional_pre_call_checks = ["router_budget_limiting"] self.retry_policy: Optional[RetryPolicy] = None if retry_policy is not None: if isinstance(retry_policy, dict): @@ -577,6 +582,10 @@ class Router: ) self.alerting_config: Optional[AlertingConfig] = alerting_config + + if optional_pre_call_checks is not None: + self.add_optional_pre_call_checks(optional_pre_call_checks) + if self.alerting_config is not None: self._initialize_alerting() @@ -612,6 +621,23 @@ class Router: f"Dictionary '{fallback_dict}' must have exactly one key, but has {len(fallback_dict)} keys." 
) + def add_optional_pre_call_checks( + self, optional_pre_call_checks: Optional[OptionalPreCallChecks] + ): + if optional_pre_call_checks is not None: + for pre_call_check in optional_pre_call_checks: + _callback: Optional[CustomLogger] = None + if pre_call_check == "prompt_caching": + _callback = PromptCachingDeploymentCheck(cache=self.cache) + elif pre_call_check == "router_budget_limiting": + _callback = RouterBudgetLimiting( + router_cache=self.cache, + provider_budget_config=self.provider_budget_config, + model_list=self.model_list, + ) + if _callback is not None: + litellm.callbacks.append(_callback) + def routing_strategy_init( self, routing_strategy: Union[RoutingStrategy, str], routing_strategy_args: dict ): @@ -1810,7 +1836,6 @@ class Router: return await litellm._arealtime(**{**data, "caching": self.cache_responses, **kwargs}) # type: ignore except Exception as e: - traceback.print_exc() if self.num_retries > 0: kwargs["model"] = model kwargs["messages"] = messages @@ -3261,6 +3286,7 @@ class Router: litellm_router_instance=self, deployment_id=id, ) + return tpm_key except Exception as e: @@ -3699,6 +3725,57 @@ class Router: ).start() # log response raise e + async def async_callback_filter_deployments( + self, + model: str, + healthy_deployments: List[dict], + messages: Optional[List[AllMessageValues]], + parent_otel_span: Optional[Span], + request_kwargs: Optional[dict] = None, + logging_obj: Optional[LiteLLMLogging] = None, + ): + """ + For usage-based-routing-v2, enables running rpm checks before the call is made, inside the semaphore. + + -> makes the calls concurrency-safe, when rpm limits are set for a deployment + + Returns: + - None + + Raises: + - Rate Limit Exception - If the deployment is over it's tpm/rpm limits + """ + returned_healthy_deployments = healthy_deployments + for _callback in litellm.callbacks: + if isinstance(_callback, CustomLogger): + try: + returned_healthy_deployments = ( + await _callback.async_filter_deployments( + model=model, + healthy_deployments=returned_healthy_deployments, + messages=messages, + request_kwargs=request_kwargs, + parent_otel_span=parent_otel_span, + ) + ) + except Exception as e: + ## LOG FAILURE EVENT + if logging_obj is not None: + asyncio.create_task( + logging_obj.async_failure_handler( + exception=e, + traceback_exception=traceback.format_exc(), + end_time=time.time(), + ) + ) + ## LOGGING + threading.Thread( + target=logging_obj.failure_handler, + args=(e, traceback.format_exc()), + ).start() # log response + raise e + return returned_healthy_deployments + def _generate_model_id(self, model_group: str, litellm_params: dict): """ Helper function to consistently generate the same id for a deployment @@ -5188,10 +5265,22 @@ class Router: cooldown_deployments=cooldown_deployments, ) + healthy_deployments = await self.async_callback_filter_deployments( + model=model, + healthy_deployments=healthy_deployments, + messages=( + cast(List[AllMessageValues], messages) + if messages is not None + else None + ), + request_kwargs=request_kwargs, + parent_otel_span=parent_otel_span, + ) + if self.enable_pre_call_checks and messages is not None: healthy_deployments = self._pre_call_checks( model=model, - healthy_deployments=healthy_deployments, + healthy_deployments=cast(List[Dict], healthy_deployments), messages=messages, request_kwargs=request_kwargs, ) @@ -5203,13 +5292,13 @@ class Router: healthy_deployments=healthy_deployments, ) - if self.router_budget_logger: - healthy_deployments = ( - await 
self.router_budget_logger.async_filter_deployments( - healthy_deployments=healthy_deployments, - request_kwargs=request_kwargs, - ) - ) + # if self.router_budget_logger: + # healthy_deployments = ( + # await self.router_budget_logger.async_filter_deployments( + # healthy_deployments=healthy_deployments, + # request_kwargs=request_kwargs, + # ) + # ) if len(healthy_deployments) == 0: exception = await async_raise_no_deployment_exception( diff --git a/litellm/router_strategy/budget_limiter.py b/litellm/router_strategy/budget_limiter.py index a11c3cf09d..8e4d675750 100644 --- a/litellm/router_strategy/budget_limiter.py +++ b/litellm/router_strategy/budget_limiter.py @@ -26,13 +26,14 @@ import litellm from litellm._logging import verbose_router_logger from litellm.caching.caching import DualCache from litellm.caching.redis_cache import RedisPipelineIncrementOperation -from litellm.integrations.custom_logger import CustomLogger +from litellm.integrations.custom_logger import CustomLogger, Span from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs from litellm.litellm_core_utils.duration_parser import duration_in_seconds from litellm.router_strategy.tag_based_routing import _get_tags_from_request_kwargs from litellm.router_utils.cooldown_callbacks import ( _get_prometheus_logger_from_callbacks, ) +from litellm.types.llms.openai import AllMessageValues from litellm.types.router import ( DeploymentTypedDict, GenericBudgetConfigType, @@ -42,13 +43,6 @@ from litellm.types.router import ( ) from litellm.types.utils import BudgetConfig, StandardLoggingPayload -if TYPE_CHECKING: - from opentelemetry.trace import Span as _Span - - Span = _Span -else: - Span = Any - DEFAULT_REDIS_SYNC_INTERVAL = 1 @@ -79,9 +73,12 @@ class RouterBudgetLimiting(CustomLogger): async def async_filter_deployments( self, - healthy_deployments: Union[List[Dict[str, Any]], Dict[str, Any]], - request_kwargs: Optional[Dict] = None, - ): + model: str, + healthy_deployments: List, + messages: Optional[List[AllMessageValues]], + request_kwargs: Optional[dict] = None, + parent_otel_span: Optional[Span] = None, # type: ignore + ) -> List[dict]: """ Filter out deployments that have exceeded their provider budget limit. 
@@ -102,11 +99,6 @@ class RouterBudgetLimiting(CustomLogger): potential_deployments: List[Dict] = [] - # Extract the parent OpenTelemetry span for tracing - parent_otel_span: Optional[Span] = _get_parent_otel_span_from_kwargs( - request_kwargs - ) - cache_keys, provider_configs, deployment_configs = ( await self._async_get_cache_keys_for_router_budget_limiting( healthy_deployments=healthy_deployments, diff --git a/litellm/router_utils/cooldown_callbacks.py b/litellm/router_utils/cooldown_callbacks.py index f6465d1358..54a016d3ec 100644 --- a/litellm/router_utils/cooldown_callbacks.py +++ b/litellm/router_utils/cooldown_callbacks.py @@ -91,8 +91,8 @@ def _get_prometheus_logger_from_callbacks() -> Optional[PrometheusLogger]: for _callback in litellm._async_success_callback: if isinstance(_callback, PrometheusLogger): return _callback - for _callback in litellm.callbacks: - if isinstance(_callback, PrometheusLogger): - return _callback + for global_callback in litellm.callbacks: + if isinstance(global_callback, PrometheusLogger): + return global_callback return None diff --git a/litellm/router_utils/pre_call_checks/prompt_caching_deployment_check.py b/litellm/router_utils/pre_call_checks/prompt_caching_deployment_check.py new file mode 100644 index 0000000000..d3d237d9f2 --- /dev/null +++ b/litellm/router_utils/pre_call_checks/prompt_caching_deployment_check.py @@ -0,0 +1,99 @@ +""" +Check if prompt caching is valid for a given deployment + +Route to previously cached model id, if valid +""" + +from typing import List, Optional, cast + +from litellm import verbose_logger +from litellm.caching.dual_cache import DualCache +from litellm.integrations.custom_logger import CustomLogger, Span +from litellm.types.llms.openai import AllMessageValues +from litellm.types.utils import CallTypes, StandardLoggingPayload +from litellm.utils import is_prompt_caching_valid_prompt + +from ..prompt_caching_cache import PromptCachingCache + + +class PromptCachingDeploymentCheck(CustomLogger): + def __init__(self, cache: DualCache): + self.cache = cache + + async def async_filter_deployments( + self, + model: str, + healthy_deployments: List, + messages: Optional[List[AllMessageValues]], + request_kwargs: Optional[dict] = None, + parent_otel_span: Optional[Span] = None, + ) -> List[dict]: + if messages is not None and is_prompt_caching_valid_prompt( + messages=messages, + model=model, + ): # prompt > 1024 tokens + prompt_cache = PromptCachingCache( + cache=self.cache, + ) + + model_id_dict = await prompt_cache.async_get_model_id( + messages=cast(List[AllMessageValues], messages), + tools=None, + ) + if model_id_dict is not None: + model_id = model_id_dict["model_id"] + for deployment in healthy_deployments: + if deployment["model_info"]["id"] == model_id: + return [deployment] + + return healthy_deployments + + async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): + standard_logging_object: Optional[StandardLoggingPayload] = kwargs.get( + "standard_logging_object", None + ) + + if standard_logging_object is None: + return + + call_type = standard_logging_object["call_type"] + + if ( + call_type != CallTypes.completion.value + and call_type != CallTypes.acompletion.value + ): # only use prompt caching for completion calls + verbose_logger.debug( + "litellm.router_utils.pre_call_checks.prompt_caching_deployment_check: skipping adding model id to prompt caching cache, CALL TYPE IS NOT COMPLETION" + ) + return + + model = standard_logging_object["model"] + messages = 
standard_logging_object["messages"] + model_id = standard_logging_object["model_id"] + + if messages is None or not isinstance(messages, list): + verbose_logger.debug( + "litellm.router_utils.pre_call_checks.prompt_caching_deployment_check: skipping adding model id to prompt caching cache, MESSAGES IS NOT A LIST" + ) + return + if model_id is None: + verbose_logger.debug( + "litellm.router_utils.pre_call_checks.prompt_caching_deployment_check: skipping adding model id to prompt caching cache, MODEL ID IS NONE" + ) + return + + ## PROMPT CACHING - cache model id, if prompt caching valid prompt + provider + if is_prompt_caching_valid_prompt( + model=model, + messages=cast(List[AllMessageValues], messages), + ): + cache = PromptCachingCache( + cache=self.cache, + ) + await cache.async_add_model_id( + model_id=model_id, + messages=messages, + tools=None, # [TODO]: add tools once standard_logging_object supports it + ) + + return diff --git a/litellm/router_utils/prompt_caching_cache.py b/litellm/router_utils/prompt_caching_cache.py index d1861dc7c8..61698ac6bc 100644 --- a/litellm/router_utils/prompt_caching_cache.py +++ b/litellm/router_utils/prompt_caching_cache.py @@ -172,22 +172,3 @@ class PromptCachingCache: cache_key = PromptCachingCache.get_prompt_caching_cache_key(messages, tools) return self.cache.get_cache(cache_key) - - async def async_get_prompt_caching_deployment( - self, - router: litellm_router, - messages: Optional[List[AllMessageValues]], - tools: Optional[List[ChatCompletionToolParam]], - ) -> Optional[dict]: - model_id_dict = await self.async_get_model_id( - messages=messages, - tools=tools, - ) - - if model_id_dict is not None: - healthy_deployment_pydantic_obj = router.get_deployment( - model_id=model_id_dict["model_id"] - ) - if healthy_deployment_pydantic_obj is not None: - return healthy_deployment_pydantic_obj.model_dump(exclude_none=True) - return None diff --git a/litellm/types/router.py b/litellm/types/router.py index 31f6fc4ecb..974c7085fc 100644 --- a/litellm/types/router.py +++ b/litellm/types/router.py @@ -667,3 +667,6 @@ class GenericBudgetWindowDetails(BaseModel): spend_key: str start_time_key: str ttl_seconds: int + + +OptionalPreCallChecks = List[Literal["prompt_caching", "router_budget_limiting"]] diff --git a/tests/local_testing/test_anthropic_prompt_caching.py b/tests/local_testing/test_anthropic_prompt_caching.py index 9bf3803314..3bf94945dc 100644 --- a/tests/local_testing/test_anthropic_prompt_caching.py +++ b/tests/local_testing/test_anthropic_prompt_caching.py @@ -653,9 +653,9 @@ async def test_router_prompt_caching_model_stored( @pytest.mark.asyncio() -@pytest.mark.skip( - reason="BETA FEATURE - skipping since this led to a latency impact, beta feature that is not used as yet" -) +# @pytest.mark.skip( +# reason="BETA FEATURE - skipping since this led to a latency impact, beta feature that is not used as yet" +# ) async def test_router_with_prompt_caching(anthropic_messages): """ if prompt caching supported model called with prompt caching valid prompt, @@ -672,15 +672,18 @@ async def test_router_with_prompt_caching(anthropic_messages): "litellm_params": { "model": "anthropic/claude-3-5-sonnet-20240620", "api_key": os.environ.get("ANTHROPIC_API_KEY"), + "mock_response": "The sky is blue.", }, }, { "model_name": "claude-model", "litellm_params": { "model": "anthropic.claude-3-5-sonnet-20241022-v2:0", + "mock_response": "The sky is green.", }, }, - ] + ], + optional_pre_call_checks=["prompt_caching"], ) response = await router.acompletion( @@ -699,6 
+702,7 @@ async def test_router_with_prompt_caching(anthropic_messages): cached_model_id = cache.get_model_id(messages=anthropic_messages, tools=None) + assert cached_model_id is not None prompt_caching_cache_key = PromptCachingCache.get_prompt_caching_cache_key( messages=anthropic_messages, tools=None ) @@ -709,18 +713,12 @@ async def test_router_with_prompt_caching(anthropic_messages): {"role": "user", "content": "What is the weather in SF?"} ] - pc_deployment = await cache.async_get_prompt_caching_deployment( - router=router, - messages=new_messages, - tools=None, - ) - assert pc_deployment is not None + for _ in range(20): + response = await router.acompletion( + messages=new_messages, + model="claude-model", + mock_response="The sky is blue.", + ) + print("response=", response) - response = await router.acompletion( - messages=new_messages, - model="claude-model", - mock_response="The sky is blue.", - ) - print("response=", response) - - assert response._hidden_params["model_id"] == initial_model_id + assert response._hidden_params["model_id"] == initial_model_id diff --git a/tests/local_testing/test_router_budget_limiter.py b/tests/local_testing/test_router_budget_limiter.py index fe25ee6dcc..305db6ccf7 100644 --- a/tests/local_testing/test_router_budget_limiter.py +++ b/tests/local_testing/test_router_budget_limiter.py @@ -304,7 +304,7 @@ async def test_prometheus_metric_tracking(): await asyncio.sleep(2.5) # Verify the mock was called correctly - mock_prometheus.track_provider_remaining_budget.assert_called_once() + mock_prometheus.track_provider_remaining_budget.assert_called() @pytest.mark.asyncio diff --git a/tests/router_unit_tests/test_router_helper_utils.py b/tests/router_unit_tests/test_router_helper_utils.py index 1961296630..2d0c702d58 100644 --- a/tests/router_unit_tests/test_router_helper_utils.py +++ b/tests/router_unit_tests/test_router_helper_utils.py @@ -1058,3 +1058,28 @@ def test_has_default_fallbacks(model_list, has_default_fallbacks, expected_resul ), ) assert router._has_default_fallbacks() is expected_result + + +def test_add_optional_pre_call_checks(model_list): + router = Router(model_list=model_list) + + router.add_optional_pre_call_checks(["prompt_caching"]) + assert len(litellm.callbacks) > 0 + + +@pytest.mark.asyncio +async def test_async_callback_filter_deployments(model_list): + from litellm.router_strategy.budget_limiter import RouterBudgetLimiting + + router = Router(model_list=model_list) + + healthy_deployments = router.get_model_list(model_name="gpt-3.5-turbo") + + new_healthy_deployments = await router.async_callback_filter_deployments( + model="gpt-3.5-turbo", + healthy_deployments=healthy_deployments, + messages=[], + parent_otel_span=None, + ) + + assert len(new_healthy_deployments) == len(healthy_deployments)
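
Below is a minimal sketch of how the `optional_pre_call_checks` Router argument introduced in this diff could be used from the Python SDK. The model name, placeholder API key, and `mock_response` value are illustrative only (they mirror the diff's tests); on the proxy, the same list is set as `optional_pre_call_checks` under `router_settings`, per the docs table change above.

```python
import asyncio

from litellm import Router


async def main() -> None:
    router = Router(
        model_list=[
            {
                "model_name": "claude-model",
                "litellm_params": {
                    "model": "anthropic/claude-3-5-sonnet-20240620",
                    "api_key": "sk-...",  # placeholder key
                },
            },
        ],
        # Values supported by this diff: "prompt_caching", "router_budget_limiting".
        # "router_budget_limiting" is appended automatically when a provider budget
        # config is present, so it rarely needs to be listed by hand.
        optional_pre_call_checks=["prompt_caching"],
    )

    # Deployment selection now flows through async_callback_filter_deployments(),
    # so PromptCachingDeploymentCheck can pin follow-up calls to the cached model id.
    response = await router.acompletion(
        model="claude-model",
        messages=[{"role": "user", "content": "What is the weather in SF?"}],
        mock_response="The sky is blue.",  # test-style mock, as in the diff's tests
    )
    print(response._hidden_params["model_id"])


if __name__ == "__main__":
    asyncio.run(main())
```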
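The diff also adds `async_filter_deployments` to the `CustomLogger` base class, widens `litellm.callbacks` to accept `CustomLogger` instances, and has the router call every registered callback's filter inside `async_callback_filter_deployments`. The sketch below shows a hypothetical custom filter built on that hook: the `EURegionOnlyCheck` class name and the `region_name`-based filtering logic are invented for illustration, while the method signature and the `litellm.callbacks` registration come from this diff.

```python
from typing import List, Optional

import litellm
from litellm.integrations.custom_logger import CustomLogger, Span
from litellm.types.llms.openai import AllMessageValues


class EURegionOnlyCheck(CustomLogger):
    """Hypothetical filter: prefer deployments whose litellm_params mark an EU region."""

    async def async_filter_deployments(
        self,
        model: str,
        healthy_deployments: List,
        messages: Optional[List[AllMessageValues]],
        request_kwargs: Optional[dict] = None,
        parent_otel_span: Optional[Span] = None,
    ) -> List[dict]:
        eu_deployments = [
            d
            for d in healthy_deployments
            if d.get("litellm_params", {}).get("region_name") == "eu"  # illustrative field
        ]
        # Fall back to the unfiltered list so the filter never empties the pool,
        # mirroring how PromptCachingDeploymentCheck returns healthy_deployments
        # unchanged when no cached model id is found.
        return eu_deployments or healthy_deployments


# Registering the instance makes the router pick it up in
# async_callback_filter_deployments(), which iterates litellm.callbacks.
litellm.callbacks.append(EURegionOnlyCheck())
```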