Litellm dev readd prompt caching (#7299)

* fix(router.py): re-add saving model id on prompt caching valid successful deployment

* fix(router.py): introduce optional pre_call_checks

isolate prompt caching logic in a separate file

* fix(prompt_caching_deployment_check.py): fix import

* fix(router.py): new 'async_filter_deployments' event hook

allows custom logger to filter deployments returned to routing strategy

* feat(prompt_caching_deployment_check.py): initial working commit of prompt caching based routing

* fix(cooldown_callbacks.py): fix linting error

* fix(budget_limiter.py): move budget logger to async_filter_deployment hook

* test: add unit test

* test(test_router_helper_utils.py): add unit testing

* fix(budget_limiter.py): fix linting errors

* docs(config_settings.md): add 'optional_pre_call_checks' to router_settings param docs
This commit is contained in:
Krish Dholakia 2024-12-18 15:13:49 -08:00 committed by GitHub
parent d214d3cc3f
commit 2f08341a08
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 276 additions and 74 deletions

View file

@ -92,6 +92,9 @@ from litellm.router_utils.handle_error import (
async_raise_no_deployment_exception,
send_llm_exception_alert,
)
from litellm.router_utils.pre_call_checks.prompt_caching_deployment_check import (
PromptCachingDeploymentCheck,
)
from litellm.router_utils.router_callbacks.track_deployment_metrics import (
increment_deployment_failures_for_current_minute,
increment_deployment_successes_for_current_minute,
@ -128,6 +131,7 @@ from litellm.types.router import (
LiteLLMParamsTypedDict,
ModelGroupInfo,
ModelInfo,
OptionalPreCallChecks,
RetryPolicy,
RouterCacheEnum,
RouterErrors,
@ -153,6 +157,7 @@ from litellm.utils import (
get_llm_provider,
get_secret,
get_utc_datetime,
is_prompt_caching_valid_prompt,
is_region_allowed,
)
@ -248,6 +253,7 @@ class Router:
"cost-based-routing",
"usage-based-routing-v2",
] = "simple-shuffle",
optional_pre_call_checks: Optional[OptionalPreCallChecks] = None,
routing_strategy_args: dict = {}, # just for latency-based
provider_budget_config: Optional[GenericBudgetConfigType] = None,
alerting_config: Optional[AlertingConfig] = None,
@ -542,11 +548,10 @@ class Router:
if RouterBudgetLimiting.should_init_router_budget_limiter(
model_list=model_list, provider_budget_config=self.provider_budget_config
):
self.router_budget_logger = RouterBudgetLimiting(
router_cache=self.cache,
provider_budget_config=self.provider_budget_config,
model_list=self.model_list,
)
if optional_pre_call_checks is not None:
optional_pre_call_checks.append("router_budget_limiting")
else:
optional_pre_call_checks = ["router_budget_limiting"]
self.retry_policy: Optional[RetryPolicy] = None
if retry_policy is not None:
if isinstance(retry_policy, dict):
@ -577,6 +582,10 @@ class Router:
)
self.alerting_config: Optional[AlertingConfig] = alerting_config
if optional_pre_call_checks is not None:
self.add_optional_pre_call_checks(optional_pre_call_checks)
if self.alerting_config is not None:
self._initialize_alerting()
@ -612,6 +621,23 @@ class Router:
f"Dictionary '{fallback_dict}' must have exactly one key, but has {len(fallback_dict)} keys."
)
def add_optional_pre_call_checks(
self, optional_pre_call_checks: Optional[OptionalPreCallChecks]
):
if optional_pre_call_checks is not None:
for pre_call_check in optional_pre_call_checks:
_callback: Optional[CustomLogger] = None
if pre_call_check == "prompt_caching":
_callback = PromptCachingDeploymentCheck(cache=self.cache)
elif pre_call_check == "router_budget_limiting":
_callback = RouterBudgetLimiting(
router_cache=self.cache,
provider_budget_config=self.provider_budget_config,
model_list=self.model_list,
)
if _callback is not None:
litellm.callbacks.append(_callback)
def routing_strategy_init(
self, routing_strategy: Union[RoutingStrategy, str], routing_strategy_args: dict
):
@ -1810,7 +1836,6 @@ class Router:
return await litellm._arealtime(**{**data, "caching": self.cache_responses, **kwargs}) # type: ignore
except Exception as e:
traceback.print_exc()
if self.num_retries > 0:
kwargs["model"] = model
kwargs["messages"] = messages
@ -3261,6 +3286,7 @@ class Router:
litellm_router_instance=self,
deployment_id=id,
)
return tpm_key
except Exception as e:
@ -3699,6 +3725,57 @@ class Router:
).start() # log response
raise e
async def async_callback_filter_deployments(
self,
model: str,
healthy_deployments: List[dict],
messages: Optional[List[AllMessageValues]],
parent_otel_span: Optional[Span],
request_kwargs: Optional[dict] = None,
logging_obj: Optional[LiteLLMLogging] = None,
):
"""
For usage-based-routing-v2, enables running rpm checks before the call is made, inside the semaphore.
-> makes the calls concurrency-safe, when rpm limits are set for a deployment
Returns:
- None
Raises:
- Rate Limit Exception - If the deployment is over it's tpm/rpm limits
"""
returned_healthy_deployments = healthy_deployments
for _callback in litellm.callbacks:
if isinstance(_callback, CustomLogger):
try:
returned_healthy_deployments = (
await _callback.async_filter_deployments(
model=model,
healthy_deployments=returned_healthy_deployments,
messages=messages,
request_kwargs=request_kwargs,
parent_otel_span=parent_otel_span,
)
)
except Exception as e:
## LOG FAILURE EVENT
if logging_obj is not None:
asyncio.create_task(
logging_obj.async_failure_handler(
exception=e,
traceback_exception=traceback.format_exc(),
end_time=time.time(),
)
)
## LOGGING
threading.Thread(
target=logging_obj.failure_handler,
args=(e, traceback.format_exc()),
).start() # log response
raise e
return returned_healthy_deployments
def _generate_model_id(self, model_group: str, litellm_params: dict):
"""
Helper function to consistently generate the same id for a deployment
@ -5188,10 +5265,22 @@ class Router:
cooldown_deployments=cooldown_deployments,
)
healthy_deployments = await self.async_callback_filter_deployments(
model=model,
healthy_deployments=healthy_deployments,
messages=(
cast(List[AllMessageValues], messages)
if messages is not None
else None
),
request_kwargs=request_kwargs,
parent_otel_span=parent_otel_span,
)
if self.enable_pre_call_checks and messages is not None:
healthy_deployments = self._pre_call_checks(
model=model,
healthy_deployments=healthy_deployments,
healthy_deployments=cast(List[Dict], healthy_deployments),
messages=messages,
request_kwargs=request_kwargs,
)
@ -5203,13 +5292,13 @@ class Router:
healthy_deployments=healthy_deployments,
)
if self.router_budget_logger:
healthy_deployments = (
await self.router_budget_logger.async_filter_deployments(
healthy_deployments=healthy_deployments,
request_kwargs=request_kwargs,
)
)
# if self.router_budget_logger:
# healthy_deployments = (
# await self.router_budget_logger.async_filter_deployments(
# healthy_deployments=healthy_deployments,
# request_kwargs=request_kwargs,
# )
# )
if len(healthy_deployments) == 0:
exception = await async_raise_no_deployment_exception(