mirror of https://github.com/BerriAI/litellm.git
synced 2025-04-25 02:34:29 +00:00
Litellm dev readd prompt caching (#7299)
* fix(router.py): re-add saving model id on prompt caching valid successful deployment
* fix(router.py): introduce optional pre_call_checks, isolate prompt caching logic in a separate file
* fix(prompt_caching_deployment_check.py): fix import
* fix(router.py): new 'async_filter_deployments' event hook allows custom logger to filter deployments returned to routing strategy
* feat(prompt_caching_deployment_check.py): initial working commit of prompt caching based routing
* fix(cooldown_callbacks.py): fix linting error
* fix(budget_limiter.py): move budget logger to async_filter_deployment hook
* test: add unit test
* test(test_router_helper_utils.py): add unit testing
* fix(budget_limiter.py): fix linting errors
* docs(config_settings.md): add 'optional_pre_call_checks' to router_settings param docs
This commit is contained in:
parent d214d3cc3f
commit 2f08341a08

12 changed files with 276 additions and 74 deletions
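The headline change is a new `optional_pre_call_checks` Router parameter, documented and tested in the hunks below. Based on the updated test at the bottom of this diff, SDK-side usage looks roughly like the following sketch (model name and mock response copied from that test):

from litellm import Router

# Registering the prompt-caching pre-call check at Router construction time.
# With it enabled, a request whose long prompt prefix was seen before is pinned
# to the deployment (model id) that served the original request.
router = Router(
    model_list=[
        {
            "model_name": "claude-model",
            "litellm_params": {
                "model": "anthropic/claude-3-5-sonnet-20240620",
                "mock_response": "The sky is blue.",
            },
        },
    ],
    optional_pre_call_checks=["prompt_caching"],
)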
@@ -285,6 +285,8 @@ router_settings:
| redis_url | str | URL for Redis server. **Known performance issue with Redis URL.** |
| cache_responses | boolean | Flag to enable caching LLM Responses, if cache set under `router_settings`. If true, caches responses. Defaults to False. |
| router_general_settings | RouterGeneralSettings | [SDK-Only] Router general settings - contains optimizations like 'async_only_mode'. [Docs](../routing.md#router-general-settings) |
+| optional_pre_call_checks | List[str] | List of pre-call checks to add to the router. Currently supported: 'router_budget_limiting', 'prompt_caching' |

### environment variables - Reference
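The two values listed in the new row map to the `OptionalPreCallChecks` alias added to `litellm/types/router.py` further down in this diff; a small typed sketch, assuming the post-commit package layout:

from litellm.types.router import OptionalPreCallChecks

# Only the two documented literals are accepted by the type.
checks: OptionalPreCallChecks = ["router_budget_limiting", "prompt_caching"]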
@@ -34,6 +34,7 @@ from litellm.proxy._types import (
    LiteLLM_UpperboundKeyGenerateParams,
)
from litellm.types.utils import StandardKeyGenerationConfig, LlmProviders
+from litellm.integrations.custom_logger import CustomLogger
import httpx
import dotenv
from enum import Enum
@@ -75,7 +76,9 @@ logged_real_time_event_types: Optional[Union[List[str], Literal["*"]]] = None
_known_custom_logger_compatible_callbacks: List = list(
    get_args(_custom_logger_compatible_callbacks_literal)
)
-callbacks: List[Union[Callable, _custom_logger_compatible_callbacks_literal]] = []
+callbacks: List[
+    Union[Callable, _custom_logger_compatible_callbacks_literal, CustomLogger]
+] = []
langfuse_default_tags: Optional[List[str]] = None
langsmith_batch_size: Optional[int] = None
argilla_batch_size: Optional[int] = None
@@ -3,7 +3,7 @@
import os
import traceback
from datetime import datetime as datetimeObj
-from typing import TYPE_CHECKING, Any, Literal, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, List, Literal, Optional, Tuple, Union

import dotenv
from pydantic import BaseModel
@@ -11,7 +11,7 @@ from pydantic import BaseModel
from litellm.caching.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.types.integrations.argilla import ArgillaItem
-from litellm.types.llms.openai import ChatCompletionRequest
+from litellm.types.llms.openai import AllMessageValues, ChatCompletionRequest
from litellm.types.services import ServiceLoggerPayload
from litellm.types.utils import (
    AdapterCompletionStreamWrapper,
@@ -69,6 +69,16 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
        Allows usage-based-routing-v2 to run pre-call rpm checks within the picked deployment's semaphore (concurrency-safe tpm/rpm checks).
        """

+    async def async_filter_deployments(
+        self,
+        model: str,
+        healthy_deployments: List,
+        messages: Optional[List[AllMessageValues]],
+        request_kwargs: Optional[dict] = None,
+        parent_otel_span: Optional[Span] = None,
+    ) -> List[dict]:
+        return healthy_deployments
+
    async def async_pre_call_check(
        self, deployment: dict, parent_otel_span: Optional[Span]
    ) -> Optional[dict]:
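Because the base implementation is a pass-through, any `CustomLogger` placed in `litellm.callbacks` can now narrow the deployment list before the routing strategy runs. A hedged sketch of a custom filter; the `region_name` criterion is illustrative only and not part of this commit:

from typing import List, Optional

import litellm
from litellm.integrations.custom_logger import CustomLogger, Span
from litellm.types.llms.openai import AllMessageValues


class EUOnlyDeployments(CustomLogger):
    """Illustrative filter: keep deployments whose litellm_params set region_name='eu'."""

    async def async_filter_deployments(
        self,
        model: str,
        healthy_deployments: List,
        messages: Optional[List[AllMessageValues]],
        request_kwargs: Optional[dict] = None,
        parent_otel_span: Optional[Span] = None,
    ) -> List[dict]:
        eu_only = [
            d
            for d in healthy_deployments
            if d.get("litellm_params", {}).get("region_name") == "eu"
        ]
        # Fall back to the full list rather than filtering everything out.
        return eu_only or healthy_deployments


# Picked up by Router.async_callback_filter_deployments (added further down in router.py).
litellm.callbacks.append(EUOnlyDeployments())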
@@ -92,6 +92,9 @@ from litellm.router_utils.handle_error import (
    async_raise_no_deployment_exception,
    send_llm_exception_alert,
)
+from litellm.router_utils.pre_call_checks.prompt_caching_deployment_check import (
+    PromptCachingDeploymentCheck,
+)
from litellm.router_utils.router_callbacks.track_deployment_metrics import (
    increment_deployment_failures_for_current_minute,
    increment_deployment_successes_for_current_minute,
@@ -128,6 +131,7 @@ from litellm.types.router import (
    LiteLLMParamsTypedDict,
    ModelGroupInfo,
    ModelInfo,
+    OptionalPreCallChecks,
    RetryPolicy,
    RouterCacheEnum,
    RouterErrors,
@@ -153,6 +157,7 @@ from litellm.utils import (
    get_llm_provider,
    get_secret,
    get_utc_datetime,
+    is_prompt_caching_valid_prompt,
    is_region_allowed,
)

@@ -248,6 +253,7 @@ class Router:
            "cost-based-routing",
            "usage-based-routing-v2",
        ] = "simple-shuffle",
+        optional_pre_call_checks: Optional[OptionalPreCallChecks] = None,
        routing_strategy_args: dict = {},  # just for latency-based
        provider_budget_config: Optional[GenericBudgetConfigType] = None,
        alerting_config: Optional[AlertingConfig] = None,
@@ -542,11 +548,10 @@ class Router:
        if RouterBudgetLimiting.should_init_router_budget_limiter(
            model_list=model_list, provider_budget_config=self.provider_budget_config
        ):
-            self.router_budget_logger = RouterBudgetLimiting(
-                router_cache=self.cache,
-                provider_budget_config=self.provider_budget_config,
-                model_list=self.model_list,
-            )
+            if optional_pre_call_checks is not None:
+                optional_pre_call_checks.append("router_budget_limiting")
+            else:
+                optional_pre_call_checks = ["router_budget_limiting"]
        self.retry_policy: Optional[RetryPolicy] = None
        if retry_policy is not None:
            if isinstance(retry_policy, dict):
@@ -577,6 +582,10 @@ class Router:
            )

        self.alerting_config: Optional[AlertingConfig] = alerting_config

+        if optional_pre_call_checks is not None:
+            self.add_optional_pre_call_checks(optional_pre_call_checks)
+
        if self.alerting_config is not None:
            self._initialize_alerting()

@@ -612,6 +621,23 @@ class Router:
                f"Dictionary '{fallback_dict}' must have exactly one key, but has {len(fallback_dict)} keys."
            )

+    def add_optional_pre_call_checks(
+        self, optional_pre_call_checks: Optional[OptionalPreCallChecks]
+    ):
+        if optional_pre_call_checks is not None:
+            for pre_call_check in optional_pre_call_checks:
+                _callback: Optional[CustomLogger] = None
+                if pre_call_check == "prompt_caching":
+                    _callback = PromptCachingDeploymentCheck(cache=self.cache)
+                elif pre_call_check == "router_budget_limiting":
+                    _callback = RouterBudgetLimiting(
+                        router_cache=self.cache,
+                        provider_budget_config=self.provider_budget_config,
+                        model_list=self.model_list,
+                    )
+                if _callback is not None:
+                    litellm.callbacks.append(_callback)
+
    def routing_strategy_init(
        self, routing_strategy: Union[RoutingStrategy, str], routing_strategy_args: dict
    ):
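The new unit test at the bottom of this diff registers a check on an existing router in essentially this way:

import litellm
from litellm import Router

router = Router(model_list=[])
router.add_optional_pre_call_checks(["prompt_caching"])

# A PromptCachingDeploymentCheck is now appended to the global callback list.
assert len(litellm.callbacks) > 0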
@@ -1810,7 +1836,6 @@ class Router:

            return await litellm._arealtime(**{**data, "caching": self.cache_responses, **kwargs})  # type: ignore
        except Exception as e:
-            traceback.print_exc()
            if self.num_retries > 0:
                kwargs["model"] = model
                kwargs["messages"] = messages
@@ -3261,6 +3286,7 @@ class Router:
                litellm_router_instance=self,
                deployment_id=id,
            )

            return tpm_key

        except Exception as e:
@@ -3699,6 +3725,57 @@ class Router:
            ).start()  # log response
            raise e

+    async def async_callback_filter_deployments(
+        self,
+        model: str,
+        healthy_deployments: List[dict],
+        messages: Optional[List[AllMessageValues]],
+        parent_otel_span: Optional[Span],
+        request_kwargs: Optional[dict] = None,
+        logging_obj: Optional[LiteLLMLogging] = None,
+    ):
+        """
+        For usage-based-routing-v2, enables running rpm checks before the call is made, inside the semaphore.
+
+        -> makes the calls concurrency-safe, when rpm limits are set for a deployment
+
+        Returns:
+        - None
+
+        Raises:
+        - Rate Limit Exception - If the deployment is over it's tpm/rpm limits
+        """
+        returned_healthy_deployments = healthy_deployments
+        for _callback in litellm.callbacks:
+            if isinstance(_callback, CustomLogger):
+                try:
+                    returned_healthy_deployments = (
+                        await _callback.async_filter_deployments(
+                            model=model,
+                            healthy_deployments=returned_healthy_deployments,
+                            messages=messages,
+                            request_kwargs=request_kwargs,
+                            parent_otel_span=parent_otel_span,
+                        )
+                    )
+                except Exception as e:
+                    ## LOG FAILURE EVENT
+                    if logging_obj is not None:
+                        asyncio.create_task(
+                            logging_obj.async_failure_handler(
+                                exception=e,
+                                traceback_exception=traceback.format_exc(),
+                                end_time=time.time(),
+                            )
+                        )
+                        ## LOGGING
+                        threading.Thread(
+                            target=logging_obj.failure_handler,
+                            args=(e, traceback.format_exc()),
+                        ).start()  # log response
+                    raise e
+        return returned_healthy_deployments
+
    def _generate_model_id(self, model_group: str, litellm_params: dict):
        """
        Helper function to consistently generate the same id for a deployment
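This helper is what hands each request's candidate deployments to every registered `CustomLogger` (via the hook added to custom_logger.py above). The new `test_async_callback_filter_deployments` test calls it directly, roughly as in this sketch:

import asyncio

from litellm import Router


async def main():
    router = Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {"model": "gpt-3.5-turbo", "mock_response": "hi"},
            }
        ]
    )
    healthy_deployments = router.get_model_list(model_name="gpt-3.5-turbo")

    filtered = await router.async_callback_filter_deployments(
        model="gpt-3.5-turbo",
        healthy_deployments=healthy_deployments,
        messages=[],
        parent_otel_span=None,
    )
    # With no filtering callbacks registered, the list passes through unchanged.
    assert len(filtered) == len(healthy_deployments)


asyncio.run(main())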
@@ -5188,10 +5265,22 @@ class Router:
            cooldown_deployments=cooldown_deployments,
        )

+        healthy_deployments = await self.async_callback_filter_deployments(
+            model=model,
+            healthy_deployments=healthy_deployments,
+            messages=(
+                cast(List[AllMessageValues], messages)
+                if messages is not None
+                else None
+            ),
+            request_kwargs=request_kwargs,
+            parent_otel_span=parent_otel_span,
+        )
+
        if self.enable_pre_call_checks and messages is not None:
            healthy_deployments = self._pre_call_checks(
                model=model,
-                healthy_deployments=healthy_deployments,
+                healthy_deployments=cast(List[Dict], healthy_deployments),
                messages=messages,
                request_kwargs=request_kwargs,
            )
@@ -5203,13 +5292,13 @@ class Router:
            healthy_deployments=healthy_deployments,
        )

-        if self.router_budget_logger:
-            healthy_deployments = (
-                await self.router_budget_logger.async_filter_deployments(
-                    healthy_deployments=healthy_deployments,
-                    request_kwargs=request_kwargs,
-                )
-            )
+        # if self.router_budget_logger:
+        #     healthy_deployments = (
+        #         await self.router_budget_logger.async_filter_deployments(
+        #             healthy_deployments=healthy_deployments,
+        #             request_kwargs=request_kwargs,
+        #         )
+        #     )

        if len(healthy_deployments) == 0:
            exception = await async_raise_no_deployment_exception(
@@ -26,13 +26,14 @@ import litellm
from litellm._logging import verbose_router_logger
from litellm.caching.caching import DualCache
from litellm.caching.redis_cache import RedisPipelineIncrementOperation
-from litellm.integrations.custom_logger import CustomLogger
+from litellm.integrations.custom_logger import CustomLogger, Span
from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
from litellm.litellm_core_utils.duration_parser import duration_in_seconds
from litellm.router_strategy.tag_based_routing import _get_tags_from_request_kwargs
from litellm.router_utils.cooldown_callbacks import (
    _get_prometheus_logger_from_callbacks,
)
+from litellm.types.llms.openai import AllMessageValues
from litellm.types.router import (
    DeploymentTypedDict,
    GenericBudgetConfigType,
@@ -42,13 +43,6 @@ from litellm.types.router import (
)
from litellm.types.utils import BudgetConfig, StandardLoggingPayload

-if TYPE_CHECKING:
-    from opentelemetry.trace import Span as _Span
-
-    Span = _Span
-else:
-    Span = Any
-
DEFAULT_REDIS_SYNC_INTERVAL = 1

@@ -79,9 +73,12 @@ class RouterBudgetLimiting(CustomLogger):

    async def async_filter_deployments(
        self,
-        healthy_deployments: Union[List[Dict[str, Any]], Dict[str, Any]],
-        request_kwargs: Optional[Dict] = None,
-    ):
+        model: str,
+        healthy_deployments: List,
+        messages: Optional[List[AllMessageValues]],
+        request_kwargs: Optional[dict] = None,
+        parent_otel_span: Optional[Span] = None,  # type: ignore
+    ) -> List[dict]:
        """
        Filter out deployments that have exceeded their provider budget limit.

@@ -102,11 +99,6 @@ class RouterBudgetLimiting(CustomLogger):

        potential_deployments: List[Dict] = []

-        # Extract the parent OpenTelemetry span for tracing
-        parent_otel_span: Optional[Span] = _get_parent_otel_span_from_kwargs(
-            request_kwargs
-        )
-
        cache_keys, provider_configs, deployment_configs = (
            await self._async_get_cache_keys_for_router_budget_limiting(
                healthy_deployments=healthy_deployments,
@@ -91,8 +91,8 @@ def _get_prometheus_logger_from_callbacks() -> Optional[PrometheusLogger]:
    for _callback in litellm._async_success_callback:
        if isinstance(_callback, PrometheusLogger):
            return _callback
-    for _callback in litellm.callbacks:
-        if isinstance(_callback, PrometheusLogger):
-            return _callback
+    for global_callback in litellm.callbacks:
+        if isinstance(global_callback, PrometheusLogger):
+            return global_callback

    return None
@@ -0,0 +1,99 @@
"""
Check if prompt caching is valid for a given deployment

Route to previously cached model id, if valid
"""

from typing import List, Optional, cast

from litellm import verbose_logger
from litellm.caching.dual_cache import DualCache
from litellm.integrations.custom_logger import CustomLogger, Span
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import CallTypes, StandardLoggingPayload
from litellm.utils import is_prompt_caching_valid_prompt

from ..prompt_caching_cache import PromptCachingCache


class PromptCachingDeploymentCheck(CustomLogger):
    def __init__(self, cache: DualCache):
        self.cache = cache

    async def async_filter_deployments(
        self,
        model: str,
        healthy_deployments: List,
        messages: Optional[List[AllMessageValues]],
        request_kwargs: Optional[dict] = None,
        parent_otel_span: Optional[Span] = None,
    ) -> List[dict]:
        if messages is not None and is_prompt_caching_valid_prompt(
            messages=messages,
            model=model,
        ):  # prompt > 1024 tokens
            prompt_cache = PromptCachingCache(
                cache=self.cache,
            )

            model_id_dict = await prompt_cache.async_get_model_id(
                messages=cast(List[AllMessageValues], messages),
                tools=None,
            )
            if model_id_dict is not None:
                model_id = model_id_dict["model_id"]
                for deployment in healthy_deployments:
                    if deployment["model_info"]["id"] == model_id:
                        return [deployment]

        return healthy_deployments

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        standard_logging_object: Optional[StandardLoggingPayload] = kwargs.get(
            "standard_logging_object", None
        )

        if standard_logging_object is None:
            return

        call_type = standard_logging_object["call_type"]

        if (
            call_type != CallTypes.completion.value
            and call_type != CallTypes.acompletion.value
        ):  # only use prompt caching for completion calls
            verbose_logger.debug(
                "litellm.router_utils.pre_call_checks.prompt_caching_deployment_check: skipping adding model id to prompt caching cache, CALL TYPE IS NOT COMPLETION"
            )
            return

        model = standard_logging_object["model"]
        messages = standard_logging_object["messages"]
        model_id = standard_logging_object["model_id"]

        if messages is None or not isinstance(messages, list):
            verbose_logger.debug(
                "litellm.router_utils.pre_call_checks.prompt_caching_deployment_check: skipping adding model id to prompt caching cache, MESSAGES IS NOT A LIST"
            )
            return
        if model_id is None:
            verbose_logger.debug(
                "litellm.router_utils.pre_call_checks.prompt_caching_deployment_check: skipping adding model id to prompt caching cache, MODEL ID IS NONE"
            )
            return

        ## PROMPT CACHING - cache model id, if prompt caching valid prompt + provider
        if is_prompt_caching_valid_prompt(
            model=model,
            messages=cast(List[AllMessageValues], messages),
        ):
            cache = PromptCachingCache(
                cache=self.cache,
            )
            await cache.async_add_model_id(
                model_id=model_id,
                messages=messages,
                tools=None,  # [TODO]: add tools once standard_logging_object supports it
            )

        return
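Taken together: `async_log_success_event` remembers which deployment served a prompt-caching-eligible prompt (the check only applies above roughly 1024 tokens), and `async_filter_deployments` later pins matching requests to that same deployment. A minimal, assumption-heavy sketch of that round trip; the deployment dicts, model id, and dummy prompt are invented for illustration:

import asyncio

from litellm.caching.dual_cache import DualCache
from litellm.router_utils.prompt_caching_cache import PromptCachingCache
from litellm.router_utils.pre_call_checks.prompt_caching_deployment_check import (
    PromptCachingDeploymentCheck,
)


async def main():
    cache = DualCache()
    check = PromptCachingDeploymentCheck(cache=cache)

    # Long prompt so is_prompt_caching_valid_prompt() sees more than 1024 tokens
    # (the model must also support prompt caching, which Claude 3.5 Sonnet does).
    messages = [{"role": "user", "content": "repeat after me " * 1200}]

    # Simulate what async_log_success_event does after a successful call:
    # remember which deployment (model id) served this prompt.
    await PromptCachingCache(cache=cache).async_add_model_id(
        model_id="deployment-1",
        messages=messages,
        tools=None,
    )

    healthy_deployments = [
        {"model_info": {"id": "deployment-1"}},
        {"model_info": {"id": "deployment-2"}},
    ]
    filtered = await check.async_filter_deployments(
        model="anthropic/claude-3-5-sonnet-20240620",
        healthy_deployments=healthy_deployments,
        messages=messages,
    )
    # Expected: only the previously cached deployment survives the filter.
    print(filtered)


asyncio.run(main())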
@@ -172,22 +172,3 @@ class PromptCachingCache:
        cache_key = PromptCachingCache.get_prompt_caching_cache_key(messages, tools)
        return self.cache.get_cache(cache_key)
-
-    async def async_get_prompt_caching_deployment(
-        self,
-        router: litellm_router,
-        messages: Optional[List[AllMessageValues]],
-        tools: Optional[List[ChatCompletionToolParam]],
-    ) -> Optional[dict]:
-        model_id_dict = await self.async_get_model_id(
-            messages=messages,
-            tools=tools,
-        )
-
-        if model_id_dict is not None:
-            healthy_deployment_pydantic_obj = router.get_deployment(
-                model_id=model_id_dict["model_id"]
-            )
-            if healthy_deployment_pydantic_obj is not None:
-                return healthy_deployment_pydantic_obj.model_dump(exclude_none=True)
-            return None
@@ -667,3 +667,6 @@ class GenericBudgetWindowDetails(BaseModel):
    spend_key: str
    start_time_key: str
    ttl_seconds: int
+
+
+OptionalPreCallChecks = List[Literal["prompt_caching", "router_budget_limiting"]]
@@ -653,9 +653,9 @@ async def test_router_prompt_caching_model_stored(


@pytest.mark.asyncio()
-@pytest.mark.skip(
-    reason="BETA FEATURE - skipping since this led to a latency impact, beta feature that is not used as yet"
-)
+# @pytest.mark.skip(
+#     reason="BETA FEATURE - skipping since this led to a latency impact, beta feature that is not used as yet"
+# )
async def test_router_with_prompt_caching(anthropic_messages):
    """
    if prompt caching supported model called with prompt caching valid prompt,
@@ -672,15 +672,18 @@ async def test_router_with_prompt_caching(anthropic_messages):
                "litellm_params": {
                    "model": "anthropic/claude-3-5-sonnet-20240620",
                    "api_key": os.environ.get("ANTHROPIC_API_KEY"),
+                    "mock_response": "The sky is blue.",
                },
            },
            {
                "model_name": "claude-model",
                "litellm_params": {
                    "model": "anthropic.claude-3-5-sonnet-20241022-v2:0",
+                    "mock_response": "The sky is green.",
                },
            },
-        ]
+        ],
+        optional_pre_call_checks=["prompt_caching"],
    )

    response = await router.acompletion(
@@ -699,6 +702,7 @@ async def test_router_with_prompt_caching(anthropic_messages):

    cached_model_id = cache.get_model_id(messages=anthropic_messages, tools=None)

+    assert cached_model_id is not None
    prompt_caching_cache_key = PromptCachingCache.get_prompt_caching_cache_key(
        messages=anthropic_messages, tools=None
    )
@@ -709,18 +713,12 @@
        {"role": "user", "content": "What is the weather in SF?"}
    ]

-    pc_deployment = await cache.async_get_prompt_caching_deployment(
-        router=router,
-        messages=new_messages,
-        tools=None,
-    )
-    assert pc_deployment is not None
+    for _ in range(20):
+        response = await router.acompletion(
+            messages=new_messages,
+            model="claude-model",
+            mock_response="The sky is blue.",
+        )
+        print("response=", response)

-    response = await router.acompletion(
-        messages=new_messages,
-        model="claude-model",
-        mock_response="The sky is blue.",
-    )
-    print("response=", response)
-
-    assert response._hidden_params["model_id"] == initial_model_id
+        assert response._hidden_params["model_id"] == initial_model_id
@@ -304,7 +304,7 @@ async def test_prometheus_metric_tracking():
    await asyncio.sleep(2.5)

    # Verify the mock was called correctly
-    mock_prometheus.track_provider_remaining_budget.assert_called_once()
+    mock_prometheus.track_provider_remaining_budget.assert_called()


@pytest.mark.asyncio
@@ -1058,3 +1058,28 @@ def test_has_default_fallbacks(model_list, has_default_fallbacks, expected_resul
        ),
    )
    assert router._has_default_fallbacks() is expected_result
+
+
+def test_add_optional_pre_call_checks(model_list):
+    router = Router(model_list=model_list)
+
+    router.add_optional_pre_call_checks(["prompt_caching"])
+    assert len(litellm.callbacks) > 0
+
+
+@pytest.mark.asyncio
+async def test_async_callback_filter_deployments(model_list):
+    from litellm.router_strategy.budget_limiter import RouterBudgetLimiting
+
+    router = Router(model_list=model_list)
+
+    healthy_deployments = router.get_model_list(model_name="gpt-3.5-turbo")
+
+    new_healthy_deployments = await router.async_callback_filter_deployments(
+        model="gpt-3.5-turbo",
+        healthy_deployments=healthy_deployments,
+        messages=[],
+        parent_otel_span=None,
+    )
+
+    assert len(new_healthy_deployments) == len(healthy_deployments)