Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-25 10:44:24 +00:00
Litellm dev readd prompt caching (#7299)
* fix(router.py): re-add saving model id on prompt caching valid successful deployment
* fix(router.py): introduce optional pre_call_checks, isolate prompt caching logic in a separate file
* fix(prompt_caching_deployment_check.py): fix import
* fix(router.py): new 'async_filter_deployments' event hook allows custom logger to filter deployments returned to routing strategy
* feat(prompt_caching_deployment_check.py): initial working commit of prompt caching based routing
* fix(cooldown_callbacks.py): fix linting error
* fix(budget_limiter.py): move budget logger to async_filter_deployment hook
* test: add unit test
* test(test_router_helper_utils.py): add unit testing
* fix(budget_limiter.py): fix linting errors
* docs(config_settings.md): add 'optional_pre_call_checks' to router_settings param docs
Parent: d214d3cc3f
Commit: 2f08341a08
12 changed files with 276 additions and 74 deletions; only the litellm/router.py diff is reproduced below.
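Before the diff, a minimal sketch of how the new 'optional_pre_call_checks' Router parameter is used. The two check names come from the dispatch code in the diff below; the model list itself is a hypothetical deployment config.

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "claude-sonnet",  # hypothetical deployment
            "litellm_params": {"model": "anthropic/claude-3-5-sonnet-20240620"},
        }
    ],
    # "prompt_caching" routes follow-up calls back to the deployment holding
    # the cached prompt; "router_budget_limiting" is appended automatically
    # when a provider budget config is set (see the __init__ hunk below).
    optional_pre_call_checks=["prompt_caching"],
)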
@@ -92,6 +92,9 @@ from litellm.router_utils.handle_error import (
     async_raise_no_deployment_exception,
     send_llm_exception_alert,
 )
+from litellm.router_utils.pre_call_checks.prompt_caching_deployment_check import (
+    PromptCachingDeploymentCheck,
+)
 from litellm.router_utils.router_callbacks.track_deployment_metrics import (
     increment_deployment_failures_for_current_minute,
     increment_deployment_successes_for_current_minute,
@@ -128,6 +131,7 @@ from litellm.types.router import (
     LiteLLMParamsTypedDict,
     ModelGroupInfo,
     ModelInfo,
+    OptionalPreCallChecks,
     RetryPolicy,
     RouterCacheEnum,
     RouterErrors,
@@ -153,6 +157,7 @@ from litellm.utils import (
     get_llm_provider,
     get_secret,
     get_utc_datetime,
+    is_prompt_caching_valid_prompt,
     is_region_allowed,
 )
 
@@ -248,6 +253,7 @@ class Router:
             "cost-based-routing",
             "usage-based-routing-v2",
         ] = "simple-shuffle",
+        optional_pre_call_checks: Optional[OptionalPreCallChecks] = None,
         routing_strategy_args: dict = {},  # just for latency-based
         provider_budget_config: Optional[GenericBudgetConfigType] = None,
         alerting_config: Optional[AlertingConfig] = None,
@@ -542,11 +548,10 @@ class Router:
         if RouterBudgetLimiting.should_init_router_budget_limiter(
             model_list=model_list, provider_budget_config=self.provider_budget_config
         ):
-            self.router_budget_logger = RouterBudgetLimiting(
-                router_cache=self.cache,
-                provider_budget_config=self.provider_budget_config,
-                model_list=self.model_list,
-            )
+            if optional_pre_call_checks is not None:
+                optional_pre_call_checks.append("router_budget_limiting")
+            else:
+                optional_pre_call_checks = ["router_budget_limiting"]
         self.retry_policy: Optional[RetryPolicy] = None
         if retry_policy is not None:
             if isinstance(retry_policy, dict):
@@ -577,6 +582,10 @@ class Router:
         )
 
         self.alerting_config: Optional[AlertingConfig] = alerting_config
+
+        if optional_pre_call_checks is not None:
+            self.add_optional_pre_call_checks(optional_pre_call_checks)
+
         if self.alerting_config is not None:
             self._initialize_alerting()
 
@@ -612,6 +621,23 @@ class Router:
                 f"Dictionary '{fallback_dict}' must have exactly one key, but has {len(fallback_dict)} keys."
             )
 
+    def add_optional_pre_call_checks(
+        self, optional_pre_call_checks: Optional[OptionalPreCallChecks]
+    ):
+        if optional_pre_call_checks is not None:
+            for pre_call_check in optional_pre_call_checks:
+                _callback: Optional[CustomLogger] = None
+                if pre_call_check == "prompt_caching":
+                    _callback = PromptCachingDeploymentCheck(cache=self.cache)
+                elif pre_call_check == "router_budget_limiting":
+                    _callback = RouterBudgetLimiting(
+                        router_cache=self.cache,
+                        provider_budget_config=self.provider_budget_config,
+                        model_list=self.model_list,
+                    )
+                if _callback is not None:
+                    litellm.callbacks.append(_callback)
+
     def routing_strategy_init(
         self, routing_strategy: Union[RoutingStrategy, str], routing_strategy_args: dict
     ):
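As a usage sketch, the method above can also be called directly, assuming 'router' is an already-constructed Router instance; the two check names are the ones handled in the dispatch:

# Registers the corresponding CustomLogger callbacks on litellm.callbacks.
router.add_optional_pre_call_checks(["prompt_caching", "router_budget_limiting"])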
@@ -1810,7 +1836,6 @@ class Router:
 
             return await litellm._arealtime(**{**data, "caching": self.cache_responses, **kwargs})  # type: ignore
         except Exception as e:
-            traceback.print_exc()
             if self.num_retries > 0:
                 kwargs["model"] = model
                 kwargs["messages"] = messages
@@ -3261,6 +3286,7 @@ class Router:
                 litellm_router_instance=self,
                 deployment_id=id,
             )
 
+            return tpm_key
 
         except Exception as e:
@@ -3699,6 +3725,57 @@ class Router:
             ).start()  # log response
             raise e
 
+    async def async_callback_filter_deployments(
+        self,
+        model: str,
+        healthy_deployments: List[dict],
+        messages: Optional[List[AllMessageValues]],
+        parent_otel_span: Optional[Span],
+        request_kwargs: Optional[dict] = None,
+        logging_obj: Optional[LiteLLMLogging] = None,
+    ):
+        """
+        For usage-based-routing-v2, enables running rpm checks before the call is made, inside the semaphore.
+
+        -> makes the calls concurrency-safe, when rpm limits are set for a deployment
+
+        Returns:
+        - the filtered list of healthy deployments
+
+        Raises:
+        - Rate Limit Exception - if the deployment is over its tpm/rpm limits
+        """
+        returned_healthy_deployments = healthy_deployments
+        for _callback in litellm.callbacks:
+            if isinstance(_callback, CustomLogger):
+                try:
+                    returned_healthy_deployments = (
+                        await _callback.async_filter_deployments(
+                            model=model,
+                            healthy_deployments=returned_healthy_deployments,
+                            messages=messages,
+                            request_kwargs=request_kwargs,
+                            parent_otel_span=parent_otel_span,
+                        )
+                    )
+                except Exception as e:
+                    ## LOG FAILURE EVENT
+                    if logging_obj is not None:
+                        asyncio.create_task(
+                            logging_obj.async_failure_handler(
+                                exception=e,
+                                traceback_exception=traceback.format_exc(),
+                                end_time=time.time(),
+                            )
+                        )
+                        ## LOGGING
+                        threading.Thread(
+                            target=logging_obj.failure_handler,
+                            args=(e, traceback.format_exc()),
+                        ).start()  # log response
+                    raise e
+        return returned_healthy_deployments
+
     def _generate_model_id(self, model_group: str, litellm_params: dict):
         """
         Helper function to consistently generate the same id for a deployment
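The loop above invokes the new 'async_filter_deployments' hook on every registered CustomLogger. A minimal sketch of a custom filter, assuming the base-class hook mirrors the keyword arguments at the call site (the class name and filter rule below are hypothetical):

from typing import List, Optional

from litellm.integrations.custom_logger import CustomLogger


class RequireModelIdFilter(CustomLogger):  # hypothetical filter
    async def async_filter_deployments(
        self,
        model: str,
        healthy_deployments: List[dict],
        messages: Optional[list] = None,
        request_kwargs: Optional[dict] = None,
        parent_otel_span=None,
    ) -> List[dict]:
        # Keep only deployments that carry a model id; raising here instead
        # would surface through the failure-logging path shown above.
        return [d for d in healthy_deployments if d.get("model_info", {}).get("id")]

Registering it mirrors what add_optional_pre_call_checks does internally: litellm.callbacks.append(RequireModelIdFilter()).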
@@ -5188,10 +5265,22 @@ class Router:
                 cooldown_deployments=cooldown_deployments,
             )
 
+            healthy_deployments = await self.async_callback_filter_deployments(
+                model=model,
+                healthy_deployments=healthy_deployments,
+                messages=(
+                    cast(List[AllMessageValues], messages)
+                    if messages is not None
+                    else None
+                ),
+                request_kwargs=request_kwargs,
+                parent_otel_span=parent_otel_span,
+            )
+
             if self.enable_pre_call_checks and messages is not None:
                 healthy_deployments = self._pre_call_checks(
                     model=model,
-                    healthy_deployments=healthy_deployments,
+                    healthy_deployments=cast(List[Dict], healthy_deployments),
                     messages=messages,
                     request_kwargs=request_kwargs,
                 )
@@ -5203,13 +5292,13 @@ class Router:
                 healthy_deployments=healthy_deployments,
             )
 
-        if self.router_budget_logger:
-            healthy_deployments = (
-                await self.router_budget_logger.async_filter_deployments(
-                    healthy_deployments=healthy_deployments,
-                    request_kwargs=request_kwargs,
-                )
-            )
+        # if self.router_budget_logger:
+        #     healthy_deployments = (
+        #         await self.router_budget_logger.async_filter_deployments(
+        #             healthy_deployments=healthy_deployments,
+        #             request_kwargs=request_kwargs,
+        #         )
+        #     )
 
         if len(healthy_deployments) == 0:
             exception = await async_raise_no_deployment_exception(
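With the direct budget-filter call retired in favor of the hook path, budget limiting is now wired up in __init__: setting a provider budget appends 'router_budget_limiting' to optional_pre_call_checks (see the @@ -542,11 +548,10 @@ hunk above). A sketch with a hypothetical budget; the dict shape follows litellm's provider budget routing docs, though whether plain dicts are accepted here rather than ProviderBudgetInfo objects is an assumption:

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-4o",  # hypothetical deployment
            "litellm_params": {"model": "openai/gpt-4o"},
        }
    ],
    # 50 USD per day across all openai deployments (hypothetical numbers)
    provider_budget_config={
        "openai": {"budget_limit": 50.0, "time_period": "1d"},
    },
)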