Litellm dev readd prompt caching (#7299)

* fix(router.py): re-add saving model id on prompt caching valid successful deployment

* fix(router.py): introduce optional pre_call_checks

isolate prompt caching logic in a separate file

* fix(prompt_caching_deployment_check.py): fix import

* fix(router.py): new 'async_filter_deployments' event hook

allows custom logger to filter deployments returned to routing strategy

* feat(prompt_caching_deployment_check.py): initial working commit of prompt caching based routing

* fix(cooldown_callbacks.py): fix linting error

* fix(budget_limiter.py): move budget logger to async_filter_deployment hook

* test: add unit test

* test(test_router_helper_utils.py): add unit testing

* fix(budget_limiter.py): fix linting errors

* docs(config_settings.md): add 'optional_pre_call_checks' to router_settings param docs
Krish Dholakia authored 2024-12-18 15:13:49 -08:00, committed by GitHub
commit 2f08341a08, parent d214d3cc3f
12 changed files with 276 additions and 74 deletions

config_settings.md

@@ -285,6 +285,8 @@ router_settings:
 | redis_url | str | URL for Redis server. **Known performance issue with Redis URL.** |
 | cache_responses | boolean | Flag to enable caching LLM Responses, if cache set under `router_settings`. If true, caches responses. Defaults to False. |
 | router_general_settings | RouterGeneralSettings | [SDK-Only] Router general settings - contains optimizations like 'async_only_mode'. [Docs](../routing.md#router-general-settings) |
+| optional_pre_call_checks | List[str] | List of pre-call checks to add to the router. Currently supported: 'router_budget_limiting', 'prompt_caching' |
 
 ### environment variables - Reference
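
For reference, a minimal sketch (not part of the diff) of setting the new option from the Python SDK; the single Anthropic deployment and `mock_response` are illustrative. On the proxy, the same list goes under `router_settings` in the config, which is what this docs table documents. Note that `router_budget_limiting` is appended automatically when provider budgets are configured (see the `router.py` change below).

```python
# Minimal sketch: enabling the new pre-call check on the SDK Router.
# The deployment (and mock_response) below is illustrative.
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "claude-model",
            "litellm_params": {
                "model": "anthropic/claude-3-5-sonnet-20240620",
                "mock_response": "Hello!",
            },
        }
    ],
    optional_pre_call_checks=["prompt_caching"],
)
```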

litellm/__init__.py

@@ -34,6 +34,7 @@ from litellm.proxy._types import (
     LiteLLM_UpperboundKeyGenerateParams,
 )
 from litellm.types.utils import StandardKeyGenerationConfig, LlmProviders
+from litellm.integrations.custom_logger import CustomLogger
 import httpx
 import dotenv
 from enum import Enum
@@ -75,7 +76,9 @@ logged_real_time_event_types: Optional[Union[List[str], Literal["*"]]] = None
 _known_custom_logger_compatible_callbacks: List = list(
     get_args(_custom_logger_compatible_callbacks_literal)
 )
-callbacks: List[Union[Callable, _custom_logger_compatible_callbacks_literal]] = []
+callbacks: List[
+    Union[Callable, _custom_logger_compatible_callbacks_literal, CustomLogger]
+] = []
 langfuse_default_tags: Optional[List[str]] = None
 langsmith_batch_size: Optional[int] = None
 argilla_batch_size: Optional[int] = None

litellm/integrations/custom_logger.py

@@ -3,7 +3,7 @@
 import os
 import traceback
 from datetime import datetime as datetimeObj
-from typing import TYPE_CHECKING, Any, Literal, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, List, Literal, Optional, Tuple, Union
 
 import dotenv
 from pydantic import BaseModel
@@ -11,7 +11,7 @@ from pydantic import BaseModel
 from litellm.caching.caching import DualCache
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.types.integrations.argilla import ArgillaItem
-from litellm.types.llms.openai import ChatCompletionRequest
+from litellm.types.llms.openai import AllMessageValues, ChatCompletionRequest
 from litellm.types.services import ServiceLoggerPayload
 from litellm.types.utils import (
     AdapterCompletionStreamWrapper,
@@ -69,6 +69,16 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
         Allows usage-based-routing-v2 to run pre-call rpm checks within the picked deployment's semaphore (concurrency-safe tpm/rpm checks).
         """
 
+    async def async_filter_deployments(
+        self,
+        model: str,
+        healthy_deployments: List,
+        messages: Optional[List[AllMessageValues]],
+        request_kwargs: Optional[dict] = None,
+        parent_otel_span: Optional[Span] = None,
+    ) -> List[dict]:
+        return healthy_deployments
+
     async def async_pre_call_check(
         self, deployment: dict, parent_otel_span: Optional[Span]
     ) -> Optional[dict]:
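
A minimal sketch of how a user-defined `CustomLogger` could override the new `async_filter_deployments` hook. The "primary"-tag filtering below is hypothetical; only the hook signature and the fact that `litellm.callbacks` now accepts `CustomLogger` instances come from this diff.

```python
# Sketch (hypothetical filter logic): prefer deployments tagged "primary" when any exist.
from typing import List, Optional

import litellm
from litellm.integrations.custom_logger import CustomLogger
from litellm.types.llms.openai import AllMessageValues


class PreferPrimaryDeployments(CustomLogger):
    async def async_filter_deployments(
        self,
        model: str,
        healthy_deployments: List,
        messages: Optional[List[AllMessageValues]],
        request_kwargs: Optional[dict] = None,
        parent_otel_span=None,
    ) -> List[dict]:
        # Keep deployments whose litellm_params carry a hypothetical "primary" tag;
        # fall back to the full list so routing never ends up with zero candidates.
        preferred = [
            d
            for d in healthy_deployments
            if "primary" in (d.get("litellm_params", {}).get("tags") or [])
        ]
        return preferred or healthy_deployments


litellm.callbacks.append(PreferPrimaryDeployments())
```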

litellm/router.py

@@ -92,6 +92,9 @@ from litellm.router_utils.handle_error import (
     async_raise_no_deployment_exception,
     send_llm_exception_alert,
 )
+from litellm.router_utils.pre_call_checks.prompt_caching_deployment_check import (
+    PromptCachingDeploymentCheck,
+)
 from litellm.router_utils.router_callbacks.track_deployment_metrics import (
     increment_deployment_failures_for_current_minute,
     increment_deployment_successes_for_current_minute,
@@ -128,6 +131,7 @@ from litellm.types.router import (
     LiteLLMParamsTypedDict,
     ModelGroupInfo,
     ModelInfo,
+    OptionalPreCallChecks,
     RetryPolicy,
     RouterCacheEnum,
     RouterErrors,
@@ -153,6 +157,7 @@ from litellm.utils import (
     get_llm_provider,
     get_secret,
     get_utc_datetime,
+    is_prompt_caching_valid_prompt,
     is_region_allowed,
 )
@@ -248,6 +253,7 @@ class Router:
             "cost-based-routing",
             "usage-based-routing-v2",
         ] = "simple-shuffle",
+        optional_pre_call_checks: Optional[OptionalPreCallChecks] = None,
        routing_strategy_args: dict = {},  # just for latency-based
        provider_budget_config: Optional[GenericBudgetConfigType] = None,
        alerting_config: Optional[AlertingConfig] = None,
@@ -542,11 +548,10 @@ class Router:
         if RouterBudgetLimiting.should_init_router_budget_limiter(
             model_list=model_list, provider_budget_config=self.provider_budget_config
         ):
-            self.router_budget_logger = RouterBudgetLimiting(
-                router_cache=self.cache,
-                provider_budget_config=self.provider_budget_config,
-                model_list=self.model_list,
-            )
+            if optional_pre_call_checks is not None:
+                optional_pre_call_checks.append("router_budget_limiting")
+            else:
+                optional_pre_call_checks = ["router_budget_limiting"]
         self.retry_policy: Optional[RetryPolicy] = None
         if retry_policy is not None:
             if isinstance(retry_policy, dict):
@@ -577,6 +582,10 @@ class Router:
         )
         self.alerting_config: Optional[AlertingConfig] = alerting_config
+
+        if optional_pre_call_checks is not None:
+            self.add_optional_pre_call_checks(optional_pre_call_checks)
+
         if self.alerting_config is not None:
             self._initialize_alerting()
@@ -612,6 +621,23 @@ class Router:
                 f"Dictionary '{fallback_dict}' must have exactly one key, but has {len(fallback_dict)} keys."
             )
 
+    def add_optional_pre_call_checks(
+        self, optional_pre_call_checks: Optional[OptionalPreCallChecks]
+    ):
+        if optional_pre_call_checks is not None:
+            for pre_call_check in optional_pre_call_checks:
+                _callback: Optional[CustomLogger] = None
+                if pre_call_check == "prompt_caching":
+                    _callback = PromptCachingDeploymentCheck(cache=self.cache)
+                elif pre_call_check == "router_budget_limiting":
+                    _callback = RouterBudgetLimiting(
+                        router_cache=self.cache,
+                        provider_budget_config=self.provider_budget_config,
+                        model_list=self.model_list,
+                    )
+                if _callback is not None:
+                    litellm.callbacks.append(_callback)
+
     def routing_strategy_init(
         self, routing_strategy: Union[RoutingStrategy, str], routing_strategy_args: dict
     ):
@@ -1810,7 +1836,6 @@
             return await litellm._arealtime(**{**data, "caching": self.cache_responses, **kwargs})  # type: ignore
         except Exception as e:
-            traceback.print_exc()
             if self.num_retries > 0:
                 kwargs["model"] = model
                 kwargs["messages"] = messages
@@ -3261,6 +3286,7 @@
                 litellm_router_instance=self,
                 deployment_id=id,
             )
+            return tpm_key
 
         except Exception as e:
@@ -3699,6 +3725,57 @@
             ).start()  # log response
             raise e
 
+    async def async_callback_filter_deployments(
+        self,
+        model: str,
+        healthy_deployments: List[dict],
+        messages: Optional[List[AllMessageValues]],
+        parent_otel_span: Optional[Span],
+        request_kwargs: Optional[dict] = None,
+        logging_obj: Optional[LiteLLMLogging] = None,
+    ):
+        """
+        For usage-based-routing-v2, enables running rpm checks before the call is made, inside the semaphore.
+
+        -> makes the calls concurrency-safe, when rpm limits are set for a deployment
+
+        Returns:
+        - None
+
+        Raises:
+        - Rate Limit Exception - If the deployment is over it's tpm/rpm limits
+        """
+        returned_healthy_deployments = healthy_deployments
+        for _callback in litellm.callbacks:
+            if isinstance(_callback, CustomLogger):
+                try:
+                    returned_healthy_deployments = (
+                        await _callback.async_filter_deployments(
+                            model=model,
+                            healthy_deployments=returned_healthy_deployments,
+                            messages=messages,
+                            request_kwargs=request_kwargs,
+                            parent_otel_span=parent_otel_span,
+                        )
+                    )
+                except Exception as e:
+                    ## LOG FAILURE EVENT
+                    if logging_obj is not None:
+                        asyncio.create_task(
+                            logging_obj.async_failure_handler(
+                                exception=e,
+                                traceback_exception=traceback.format_exc(),
+                                end_time=time.time(),
+                            )
+                        )
+                        ## LOGGING
+                        threading.Thread(
+                            target=logging_obj.failure_handler,
+                            args=(e, traceback.format_exc()),
+                        ).start()  # log response
+                    raise e
+        return returned_healthy_deployments
+
     def _generate_model_id(self, model_group: str, litellm_params: dict):
         """
         Helper function to consistently generate the same id for a deployment
@@ -5188,10 +5265,22 @@
                 cooldown_deployments=cooldown_deployments,
             )
 
+            healthy_deployments = await self.async_callback_filter_deployments(
+                model=model,
+                healthy_deployments=healthy_deployments,
+                messages=(
+                    cast(List[AllMessageValues], messages)
+                    if messages is not None
+                    else None
+                ),
+                request_kwargs=request_kwargs,
+                parent_otel_span=parent_otel_span,
+            )
+
             if self.enable_pre_call_checks and messages is not None:
                 healthy_deployments = self._pre_call_checks(
                     model=model,
-                    healthy_deployments=healthy_deployments,
+                    healthy_deployments=cast(List[Dict], healthy_deployments),
                     messages=messages,
                     request_kwargs=request_kwargs,
                 )
@@ -5203,13 +5292,13 @@
                 healthy_deployments=healthy_deployments,
             )
 
-            if self.router_budget_logger:
-                healthy_deployments = (
-                    await self.router_budget_logger.async_filter_deployments(
-                        healthy_deployments=healthy_deployments,
-                        request_kwargs=request_kwargs,
-                    )
-                )
+            # if self.router_budget_logger:
+            #     healthy_deployments = (
+            #         await self.router_budget_logger.async_filter_deployments(
+            #             healthy_deployments=healthy_deployments,
+            #             request_kwargs=request_kwargs,
+            #         )
+            #     )
 
             if len(healthy_deployments) == 0:
                 exception = await async_raise_no_deployment_exception(
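
A condensed sketch of exercising the new `Router.async_callback_filter_deployments` hook directly, mirroring the new unit test below; the single gpt-3.5-turbo deployment is illustrative.

```python
# Condensed sketch (mirrors the new unit test): every CustomLogger registered in
# litellm.callbacks gets a chance to narrow the healthy-deployment list.
import asyncio

from litellm import Router


async def main():
    router = Router(
        model_list=[
            {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "gpt-3.5-turbo"}}
        ]
    )
    healthy_deployments = router.get_model_list(model_name="gpt-3.5-turbo")

    filtered = await router.async_callback_filter_deployments(
        model="gpt-3.5-turbo",
        healthy_deployments=healthy_deployments,
        messages=[],
        parent_otel_span=None,
    )
    # With no filtering callbacks registered, the list is returned unchanged.
    print(len(filtered))


asyncio.run(main())
```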

litellm/router_strategy/budget_limiter.py

@@ -26,13 +26,14 @@ import litellm
 from litellm._logging import verbose_router_logger
 from litellm.caching.caching import DualCache
 from litellm.caching.redis_cache import RedisPipelineIncrementOperation
-from litellm.integrations.custom_logger import CustomLogger
+from litellm.integrations.custom_logger import CustomLogger, Span
 from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
 from litellm.litellm_core_utils.duration_parser import duration_in_seconds
 from litellm.router_strategy.tag_based_routing import _get_tags_from_request_kwargs
 from litellm.router_utils.cooldown_callbacks import (
     _get_prometheus_logger_from_callbacks,
 )
+from litellm.types.llms.openai import AllMessageValues
 from litellm.types.router import (
     DeploymentTypedDict,
     GenericBudgetConfigType,
@@ -42,13 +43,6 @@ from litellm.types.router import (
 )
 from litellm.types.utils import BudgetConfig, StandardLoggingPayload
 
-if TYPE_CHECKING:
-    from opentelemetry.trace import Span as _Span
-
-    Span = _Span
-else:
-    Span = Any
-
 DEFAULT_REDIS_SYNC_INTERVAL = 1
@@ -79,9 +73,12 @@ class RouterBudgetLimiting(CustomLogger):
     async def async_filter_deployments(
         self,
-        healthy_deployments: Union[List[Dict[str, Any]], Dict[str, Any]],
-        request_kwargs: Optional[Dict] = None,
-    ):
+        model: str,
+        healthy_deployments: List,
+        messages: Optional[List[AllMessageValues]],
+        request_kwargs: Optional[dict] = None,
+        parent_otel_span: Optional[Span] = None,  # type: ignore
+    ) -> List[dict]:
         """
         Filter out deployments that have exceeded their provider budget limit.
@@ -102,11 +99,6 @@ class RouterBudgetLimiting(CustomLogger):
         potential_deployments: List[Dict] = []
 
-        # Extract the parent OpenTelemetry span for tracing
-        parent_otel_span: Optional[Span] = _get_parent_otel_span_from_kwargs(
-            request_kwargs
-        )
-
         cache_keys, provider_configs, deployment_configs = (
             await self._async_get_cache_keys_for_router_budget_limiting(
                 healthy_deployments=healthy_deployments,

litellm/router_utils/cooldown_callbacks.py

@@ -91,8 +91,8 @@ def _get_prometheus_logger_from_callbacks() -> Optional[PrometheusLogger]:
     for _callback in litellm._async_success_callback:
         if isinstance(_callback, PrometheusLogger):
             return _callback
-    for _callback in litellm.callbacks:
-        if isinstance(_callback, PrometheusLogger):
-            return _callback
+    for global_callback in litellm.callbacks:
+        if isinstance(global_callback, PrometheusLogger):
+            return global_callback
 
     return None

litellm/router_utils/pre_call_checks/prompt_caching_deployment_check.py (new file)

@@ -0,0 +1,99 @@
+"""
+Check if prompt caching is valid for a given deployment
+
+Route to previously cached model id, if valid
+"""
+
+from typing import List, Optional, cast
+
+from litellm import verbose_logger
+from litellm.caching.dual_cache import DualCache
+from litellm.integrations.custom_logger import CustomLogger, Span
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.utils import CallTypes, StandardLoggingPayload
+from litellm.utils import is_prompt_caching_valid_prompt
+
+from ..prompt_caching_cache import PromptCachingCache
+
+
+class PromptCachingDeploymentCheck(CustomLogger):
+    def __init__(self, cache: DualCache):
+        self.cache = cache
+
+    async def async_filter_deployments(
+        self,
+        model: str,
+        healthy_deployments: List,
+        messages: Optional[List[AllMessageValues]],
+        request_kwargs: Optional[dict] = None,
+        parent_otel_span: Optional[Span] = None,
+    ) -> List[dict]:
+        if messages is not None and is_prompt_caching_valid_prompt(
+            messages=messages,
+            model=model,
+        ):  # prompt > 1024 tokens
+            prompt_cache = PromptCachingCache(
+                cache=self.cache,
+            )
+            model_id_dict = await prompt_cache.async_get_model_id(
+                messages=cast(List[AllMessageValues], messages),
+                tools=None,
+            )
+            if model_id_dict is not None:
+                model_id = model_id_dict["model_id"]
+                for deployment in healthy_deployments:
+                    if deployment["model_info"]["id"] == model_id:
+                        return [deployment]
+
+        return healthy_deployments
+
+    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+        standard_logging_object: Optional[StandardLoggingPayload] = kwargs.get(
+            "standard_logging_object", None
+        )
+        if standard_logging_object is None:
+            return
+
+        call_type = standard_logging_object["call_type"]
+
+        if (
+            call_type != CallTypes.completion.value
+            and call_type != CallTypes.acompletion.value
+        ):  # only use prompt caching for completion calls
+            verbose_logger.debug(
+                "litellm.router_utils.pre_call_checks.prompt_caching_deployment_check: skipping adding model id to prompt caching cache, CALL TYPE IS NOT COMPLETION"
+            )
+            return
+
+        model = standard_logging_object["model"]
+        messages = standard_logging_object["messages"]
+        model_id = standard_logging_object["model_id"]
+
+        if messages is None or not isinstance(messages, list):
+            verbose_logger.debug(
+                "litellm.router_utils.pre_call_checks.prompt_caching_deployment_check: skipping adding model id to prompt caching cache, MESSAGES IS NOT A LIST"
+            )
+            return
+        if model_id is None:
+            verbose_logger.debug(
+                "litellm.router_utils.pre_call_checks.prompt_caching_deployment_check: skipping adding model id to prompt caching cache, MODEL ID IS NONE"
+            )
+            return
+
+        ## PROMPT CACHING - cache model id, if prompt caching valid prompt + provider
+        if is_prompt_caching_valid_prompt(
+            model=model,
+            messages=cast(List[AllMessageValues], messages),
+        ):
+            cache = PromptCachingCache(
+                cache=self.cache,
+            )
+            await cache.async_add_model_id(
+                model_id=model_id,
+                messages=messages,
+                tools=None,  # [TODO]: add tools once standard_logging_object supports it
+            )
+
+        return
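
For an end-to-end picture, here is a condensed, self-contained variant of the updated test below: with the "prompt_caching" check enabled, a repeat call with the same >1024-token, `cache_control`-marked prompt should be routed to whichever deployment served (and cached) the first call. The model names, mock responses, and the long prompt are illustrative.

```python
# Rough usage sketch (mirrors the updated test's assertion; models/prompt are illustrative).
import asyncio

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "claude-model",
            "litellm_params": {
                "model": "anthropic/claude-3-5-sonnet-20240620",
                "mock_response": "The sky is blue.",
            },
        },
        {
            "model_name": "claude-model",
            "litellm_params": {
                "model": "anthropic.claude-3-5-sonnet-20241022-v2:0",
                "mock_response": "The sky is green.",
            },
        },
    ],
    optional_pre_call_checks=["prompt_caching"],
)


async def main():
    long_messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "Here is the full text of a long agreement: ..." * 400,
                    "cache_control": {"type": "ephemeral"},
                }
            ],
        }
    ]
    first = await router.acompletion(model="claude-model", messages=long_messages)
    # Give the async success callback a moment to store the model id in the cache.
    await asyncio.sleep(1)
    second = await router.acompletion(model="claude-model", messages=long_messages)

    # Expected to match: both calls land on the same underlying deployment.
    assert first._hidden_params["model_id"] == second._hidden_params["model_id"]


asyncio.run(main())
```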

litellm/router_utils/prompt_caching_cache.py

@@ -172,22 +172,3 @@ class PromptCachingCache:
         cache_key = PromptCachingCache.get_prompt_caching_cache_key(messages, tools)
         return self.cache.get_cache(cache_key)
-
-    async def async_get_prompt_caching_deployment(
-        self,
-        router: litellm_router,
-        messages: Optional[List[AllMessageValues]],
-        tools: Optional[List[ChatCompletionToolParam]],
-    ) -> Optional[dict]:
-        model_id_dict = await self.async_get_model_id(
-            messages=messages,
-            tools=tools,
-        )
-
-        if model_id_dict is not None:
-            healthy_deployment_pydantic_obj = router.get_deployment(
-                model_id=model_id_dict["model_id"]
-            )
-            if healthy_deployment_pydantic_obj is not None:
-                return healthy_deployment_pydantic_obj.model_dump(exclude_none=True)
-
-        return None

litellm/types/router.py

@@ -667,3 +667,6 @@ class GenericBudgetWindowDetails(BaseModel):
     spend_key: str
     start_time_key: str
     ttl_seconds: int
+
+
+OptionalPreCallChecks = List[Literal["prompt_caching", "router_budget_limiting"]]
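
The new `OptionalPreCallChecks` alias is what both the `Router` constructor and `Router.add_optional_pre_call_checks` accept. A small sketch of registering a check on an already-constructed router, mirroring the new unit test below (the model list is illustrative):

```python
# Sketch: adding a pre-call check after the Router is constructed.
# Each enabled check is appended to litellm.callbacks as a CustomLogger instance.
import litellm
from litellm import Router

router = Router(
    model_list=[
        {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "gpt-3.5-turbo"}}
    ]
)
router.add_optional_pre_call_checks(["prompt_caching"])

print([type(cb).__name__ for cb in litellm.callbacks])  # includes PromptCachingDeploymentCheck
```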


@@ -653,9 +653,9 @@ async def test_router_prompt_caching_model_stored(
 @pytest.mark.asyncio()
-@pytest.mark.skip(
-    reason="BETA FEATURE - skipping since this led to a latency impact, beta feature that is not used as yet"
-)
+# @pytest.mark.skip(
+#     reason="BETA FEATURE - skipping since this led to a latency impact, beta feature that is not used as yet"
+# )
 async def test_router_with_prompt_caching(anthropic_messages):
     """
     if prompt caching supported model called with prompt caching valid prompt,
@@ -672,15 +672,18 @@ async def test_router_with_prompt_caching(anthropic_messages):
                 "litellm_params": {
                     "model": "anthropic/claude-3-5-sonnet-20240620",
                     "api_key": os.environ.get("ANTHROPIC_API_KEY"),
+                    "mock_response": "The sky is blue.",
                 },
             },
             {
                 "model_name": "claude-model",
                 "litellm_params": {
                     "model": "anthropic.claude-3-5-sonnet-20241022-v2:0",
+                    "mock_response": "The sky is green.",
                 },
             },
-        ]
+        ],
+        optional_pre_call_checks=["prompt_caching"],
     )
 
     response = await router.acompletion(
@@ -699,6 +702,7 @@
 
     cached_model_id = cache.get_model_id(messages=anthropic_messages, tools=None)
 
+    assert cached_model_id is not None
     prompt_caching_cache_key = PromptCachingCache.get_prompt_caching_cache_key(
         messages=anthropic_messages, tools=None
     )
@@ -709,18 +713,12 @@
         {"role": "user", "content": "What is the weather in SF?"}
     ]
 
-    pc_deployment = await cache.async_get_prompt_caching_deployment(
-        router=router,
-        messages=new_messages,
-        tools=None,
-    )
-    assert pc_deployment is not None
-
-    response = await router.acompletion(
-        messages=new_messages,
-        model="claude-model",
-        mock_response="The sky is blue.",
-    )
-    print("response=", response)
-
-    assert response._hidden_params["model_id"] == initial_model_id
+    for _ in range(20):
+        response = await router.acompletion(
+            messages=new_messages,
+            model="claude-model",
+            mock_response="The sky is blue.",
+        )
+        print("response=", response)
+
+        assert response._hidden_params["model_id"] == initial_model_id


@@ -304,7 +304,7 @@ async def test_prometheus_metric_tracking():
         await asyncio.sleep(2.5)
 
         # Verify the mock was called correctly
-        mock_prometheus.track_provider_remaining_budget.assert_called_once()
+        mock_prometheus.track_provider_remaining_budget.assert_called()
 
 @pytest.mark.asyncio

test_router_helper_utils.py

@@ -1058,3 +1058,28 @@ def test_has_default_fallbacks(model_list, has_default_fallbacks, expected_result
         ),
     )
     assert router._has_default_fallbacks() is expected_result
+
+
+def test_add_optional_pre_call_checks(model_list):
+    router = Router(model_list=model_list)
+
+    router.add_optional_pre_call_checks(["prompt_caching"])
+    assert len(litellm.callbacks) > 0
+
+
+@pytest.mark.asyncio
+async def test_async_callback_filter_deployments(model_list):
+    from litellm.router_strategy.budget_limiter import RouterBudgetLimiting
+
+    router = Router(model_list=model_list)
+
+    healthy_deployments = router.get_model_list(model_name="gpt-3.5-turbo")
+
+    new_healthy_deployments = await router.async_callback_filter_deployments(
+        model="gpt-3.5-turbo",
+        healthy_deployments=healthy_deployments,
+        messages=[],
+        parent_otel_span=None,
+    )
+
+    assert len(new_healthy_deployments) == len(healthy_deployments)