Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)
Litellm dev readd prompt caching (#7299)
* fix(router.py): re-add saving model id on prompt caching valid successful deployment
* fix(router.py): introduce optional pre_call_checks; isolate prompt caching logic in a separate file
* fix(prompt_caching_deployment_check.py): fix import
* fix(router.py): new 'async_filter_deployments' event hook allows custom logger to filter deployments returned to routing strategy
* feat(prompt_caching_deployment_check.py): initial working commit of prompt caching based routing
* fix(cooldown_callbacks.py): fix linting error
* fix(budget_limiter.py): move budget logger to async_filter_deployment hook
* test: add unit test
* test(test_router_helper_utils.py): add unit testing
* fix(budget_limiter.py): fix linting errors
* docs(config_settings.md): add 'optional_pre_call_checks' to router_settings param docs
Parent: d214d3cc3f
Commit: 2f08341a08
12 changed files with 276 additions and 74 deletions
@@ -285,6 +285,8 @@ router_settings:
 | redis_url | str | URL for Redis server. **Known performance issue with Redis URL.** |
 | cache_responses | boolean | Flag to enable caching LLM Responses, if cache set under `router_settings`. If true, caches responses. Defaults to False. |
 | router_general_settings | RouterGeneralSettings | [SDK-Only] Router general settings - contains optimizations like 'async_only_mode'. [Docs](../routing.md#router-general-settings) |
+| optional_pre_call_checks | List[str] | List of pre-call checks to add to the router. Currently supported: 'router_budget_limiting', 'prompt_caching' |
+
 ### environment variables - Reference
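
For reference, a minimal sketch of enabling the same checks from the Python SDK, based on the `optional_pre_call_checks` parameter this commit adds to `Router`; the model name and API key below are placeholders:

```python
from litellm import Router

# Placeholder deployment; swap in your own model list.
router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {"model": "gpt-3.5-turbo", "api_key": "sk-..."},
        }
    ],
    # Supported values per the table above:
    # "prompt_caching", "router_budget_limiting"
    optional_pre_call_checks=["prompt_caching"],
)
```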
@@ -34,6 +34,7 @@ from litellm.proxy._types import (
     LiteLLM_UpperboundKeyGenerateParams,
 )
 from litellm.types.utils import StandardKeyGenerationConfig, LlmProviders
+from litellm.integrations.custom_logger import CustomLogger
 import httpx
 import dotenv
 from enum import Enum
@@ -75,7 +76,9 @@ logged_real_time_event_types: Optional[Union[List[str], Literal["*"]]] = None
 _known_custom_logger_compatible_callbacks: List = list(
     get_args(_custom_logger_compatible_callbacks_literal)
 )
-callbacks: List[Union[Callable, _custom_logger_compatible_callbacks_literal]] = []
+callbacks: List[
+    Union[Callable, _custom_logger_compatible_callbacks_literal, CustomLogger]
+] = []
 langfuse_default_tags: Optional[List[str]] = None
 langsmith_batch_size: Optional[int] = None
 argilla_batch_size: Optional[int] = None
@@ -3,7 +3,7 @@
 import os
 import traceback
 from datetime import datetime as datetimeObj
-from typing import TYPE_CHECKING, Any, Literal, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, List, Literal, Optional, Tuple, Union
 
 import dotenv
 from pydantic import BaseModel
@@ -11,7 +11,7 @@ from pydantic import BaseModel
 from litellm.caching.caching import DualCache
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.types.integrations.argilla import ArgillaItem
-from litellm.types.llms.openai import ChatCompletionRequest
+from litellm.types.llms.openai import AllMessageValues, ChatCompletionRequest
 from litellm.types.services import ServiceLoggerPayload
 from litellm.types.utils import (
     AdapterCompletionStreamWrapper,
@@ -69,6 +69,16 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
         Allows usage-based-routing-v2 to run pre-call rpm checks within the picked deployment's semaphore (concurrency-safe tpm/rpm checks).
         """
 
+    async def async_filter_deployments(
+        self,
+        model: str,
+        healthy_deployments: List,
+        messages: Optional[List[AllMessageValues]],
+        request_kwargs: Optional[dict] = None,
+        parent_otel_span: Optional[Span] = None,
+    ) -> List[dict]:
+        return healthy_deployments
+
     async def async_pre_call_check(
         self, deployment: dict, parent_otel_span: Optional[Span]
     ) -> Optional[dict]:
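
The `async_filter_deployments` hook added above lets any `CustomLogger` registered in `litellm.callbacks` narrow the deployment list before the routing strategy picks one. A minimal sketch of a custom filter; the class name and the `supports_vision` model-info key it filters on are illustrative, not part of this commit:

```python
from typing import List, Optional

import litellm
from litellm.integrations.custom_logger import CustomLogger


class VisionOnlyFilter(CustomLogger):  # hypothetical example filter
    """Drops deployments whose model_info does not mark vision support."""

    async def async_filter_deployments(
        self,
        model: str,
        healthy_deployments: List,
        messages: Optional[List] = None,
        request_kwargs: Optional[dict] = None,
        parent_otel_span=None,
    ) -> List[dict]:
        filtered = [
            d
            for d in healthy_deployments
            if d.get("model_info", {}).get("supports_vision") is True
        ]
        # Fall back to the full list rather than returning nothing.
        return filtered or healthy_deployments


# Registering the filter the same way add_optional_pre_call_checks does.
litellm.callbacks.append(VisionOnlyFilter())
```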
@@ -92,6 +92,9 @@ from litellm.router_utils.handle_error import (
     async_raise_no_deployment_exception,
     send_llm_exception_alert,
 )
+from litellm.router_utils.pre_call_checks.prompt_caching_deployment_check import (
+    PromptCachingDeploymentCheck,
+)
 from litellm.router_utils.router_callbacks.track_deployment_metrics import (
     increment_deployment_failures_for_current_minute,
     increment_deployment_successes_for_current_minute,
@@ -128,6 +131,7 @@ from litellm.types.router import (
     LiteLLMParamsTypedDict,
     ModelGroupInfo,
     ModelInfo,
+    OptionalPreCallChecks,
     RetryPolicy,
     RouterCacheEnum,
     RouterErrors,
@@ -153,6 +157,7 @@ from litellm.utils import (
     get_llm_provider,
     get_secret,
     get_utc_datetime,
+    is_prompt_caching_valid_prompt,
     is_region_allowed,
 )
 
@@ -248,6 +253,7 @@ class Router:
             "cost-based-routing",
             "usage-based-routing-v2",
         ] = "simple-shuffle",
+        optional_pre_call_checks: Optional[OptionalPreCallChecks] = None,
        routing_strategy_args: dict = {},  # just for latency-based
        provider_budget_config: Optional[GenericBudgetConfigType] = None,
        alerting_config: Optional[AlertingConfig] = None,
@@ -542,11 +548,10 @@ class Router:
        if RouterBudgetLimiting.should_init_router_budget_limiter(
            model_list=model_list, provider_budget_config=self.provider_budget_config
        ):
-            self.router_budget_logger = RouterBudgetLimiting(
-                router_cache=self.cache,
-                provider_budget_config=self.provider_budget_config,
-                model_list=self.model_list,
-            )
+            if optional_pre_call_checks is not None:
+                optional_pre_call_checks.append("router_budget_limiting")
+            else:
+                optional_pre_call_checks = ["router_budget_limiting"]
        self.retry_policy: Optional[RetryPolicy] = None
        if retry_policy is not None:
            if isinstance(retry_policy, dict):
@@ -577,6 +582,10 @@ class Router:
        )

        self.alerting_config: Optional[AlertingConfig] = alerting_config

+        if optional_pre_call_checks is not None:
+            self.add_optional_pre_call_checks(optional_pre_call_checks)
+
        if self.alerting_config is not None:
            self._initialize_alerting()

@@ -612,6 +621,23 @@ class Router:
                    f"Dictionary '{fallback_dict}' must have exactly one key, but has {len(fallback_dict)} keys."
                )

+    def add_optional_pre_call_checks(
+        self, optional_pre_call_checks: Optional[OptionalPreCallChecks]
+    ):
+        if optional_pre_call_checks is not None:
+            for pre_call_check in optional_pre_call_checks:
+                _callback: Optional[CustomLogger] = None
+                if pre_call_check == "prompt_caching":
+                    _callback = PromptCachingDeploymentCheck(cache=self.cache)
+                elif pre_call_check == "router_budget_limiting":
+                    _callback = RouterBudgetLimiting(
+                        router_cache=self.cache,
+                        provider_budget_config=self.provider_budget_config,
+                        model_list=self.model_list,
+                    )
+                if _callback is not None:
+                    litellm.callbacks.append(_callback)
+
    def routing_strategy_init(
        self, routing_strategy: Union[RoutingStrategy, str], routing_strategy_args: dict
    ):
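
With this change, a provider budget no longer creates a dedicated `router_budget_logger`; it appends `"router_budget_limiting"` to the optional pre-call checks, so the `RouterBudgetLimiting` callback is registered through the same path as `"prompt_caching"`. A hedged sketch; the budget config keys (`budget_limit`, `time_period`) follow litellm's provider-budget docs and the dollar amounts and model entries are placeholders:

```python
from litellm import Router

# Assumed per-provider budget shape: max spend (USD) over a rolling window.
provider_budget_config = {
    "openai": {"budget_limit": 50.0, "time_period": "1d"},
    "anthropic": {"budget_limit": 100.0, "time_period": "1d"},
}

router = Router(
    model_list=[
        {
            "model_name": "gpt-4o",
            "litellm_params": {"model": "openai/gpt-4o", "api_key": "sk-..."},
        }
    ],
    provider_budget_config=provider_budget_config,
    # "router_budget_limiting" is appended automatically when a budget config
    # is present; checks passed here are merged with it.
    optional_pre_call_checks=["prompt_caching"],
)
```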
@@ -1810,7 +1836,6 @@ class Router:

            return await litellm._arealtime(**{**data, "caching": self.cache_responses, **kwargs})  # type: ignore
        except Exception as e:
-            traceback.print_exc()
            if self.num_retries > 0:
                kwargs["model"] = model
                kwargs["messages"] = messages
@@ -3261,6 +3286,7 @@ class Router:
                litellm_router_instance=self,
                deployment_id=id,
            )

            return tpm_key

        except Exception as e:
@@ -3699,6 +3725,57 @@ class Router:
                ).start()  # log response
            raise e

+    async def async_callback_filter_deployments(
+        self,
+        model: str,
+        healthy_deployments: List[dict],
+        messages: Optional[List[AllMessageValues]],
+        parent_otel_span: Optional[Span],
+        request_kwargs: Optional[dict] = None,
+        logging_obj: Optional[LiteLLMLogging] = None,
+    ):
+        """
+        For usage-based-routing-v2, enables running rpm checks before the call is made, inside the semaphore.
+
+        -> makes the calls concurrency-safe, when rpm limits are set for a deployment
+
+        Returns:
+        - None
+
+        Raises:
+        - Rate Limit Exception - If the deployment is over it's tpm/rpm limits
+        """
+        returned_healthy_deployments = healthy_deployments
+        for _callback in litellm.callbacks:
+            if isinstance(_callback, CustomLogger):
+                try:
+                    returned_healthy_deployments = (
+                        await _callback.async_filter_deployments(
+                            model=model,
+                            healthy_deployments=returned_healthy_deployments,
+                            messages=messages,
+                            request_kwargs=request_kwargs,
+                            parent_otel_span=parent_otel_span,
+                        )
+                    )
+                except Exception as e:
+                    ## LOG FAILURE EVENT
+                    if logging_obj is not None:
+                        asyncio.create_task(
+                            logging_obj.async_failure_handler(
+                                exception=e,
+                                traceback_exception=traceback.format_exc(),
+                                end_time=time.time(),
+                            )
+                        )
+                        ## LOGGING
+                        threading.Thread(
+                            target=logging_obj.failure_handler,
+                            args=(e, traceback.format_exc()),
+                        ).start()  # log response
+                    raise e
+        return returned_healthy_deployments
+
    def _generate_model_id(self, model_group: str, litellm_params: dict):
        """
        Helper function to consistently generate the same id for a deployment
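
`async_callback_filter_deployments` threads the healthy-deployment list through every registered `CustomLogger`, each of which may shrink it, and surfaces any callback failure through the logging object before re-raising. A simplified, self-contained sketch of that chaining pattern (plain functions stand in for the callbacks; nothing below is litellm API):

```python
import asyncio
from typing import Awaitable, Callable, List

Deployment = dict
Filter = Callable[[List[Deployment]], Awaitable[List[Deployment]]]


async def keep_cached_model_id(deployments: List[Deployment]) -> List[Deployment]:
    # Stand-in for PromptCachingDeploymentCheck: pin to a previously cached deployment id.
    pinned = [d for d in deployments if d.get("model_info", {}).get("id") == "dep-1"]
    return pinned or deployments


async def drop_over_budget(deployments: List[Deployment]) -> List[Deployment]:
    # Stand-in for RouterBudgetLimiting: drop deployments flagged as over budget.
    return [d for d in deployments if not d.get("over_budget", False)]


async def filter_deployments(
    deployments: List[Deployment], filters: List[Filter]
) -> List[Deployment]:
    # Each filter receives the output of the previous one, mirroring the loop
    # over litellm.callbacks in async_callback_filter_deployments.
    for f in filters:
        deployments = await f(deployments)
    return deployments


if __name__ == "__main__":
    healthy = [
        {"model_info": {"id": "dep-1"}},
        {"model_info": {"id": "dep-2"}, "over_budget": True},
    ]
    result = asyncio.run(
        filter_deployments(healthy, [keep_cached_model_id, drop_over_budget])
    )
    print(result)  # [{'model_info': {'id': 'dep-1'}}]
```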
@@ -5188,10 +5265,22 @@ class Router:
            cooldown_deployments=cooldown_deployments,
        )

+        healthy_deployments = await self.async_callback_filter_deployments(
+            model=model,
+            healthy_deployments=healthy_deployments,
+            messages=(
+                cast(List[AllMessageValues], messages)
+                if messages is not None
+                else None
+            ),
+            request_kwargs=request_kwargs,
+            parent_otel_span=parent_otel_span,
+        )
+
        if self.enable_pre_call_checks and messages is not None:
            healthy_deployments = self._pre_call_checks(
                model=model,
-                healthy_deployments=healthy_deployments,
+                healthy_deployments=cast(List[Dict], healthy_deployments),
                messages=messages,
                request_kwargs=request_kwargs,
            )
@@ -5203,13 +5292,13 @@ class Router:
            healthy_deployments=healthy_deployments,
        )

-        if self.router_budget_logger:
-            healthy_deployments = (
-                await self.router_budget_logger.async_filter_deployments(
-                    healthy_deployments=healthy_deployments,
-                    request_kwargs=request_kwargs,
-                )
-            )
+        # if self.router_budget_logger:
+        #     healthy_deployments = (
+        #         await self.router_budget_logger.async_filter_deployments(
+        #             healthy_deployments=healthy_deployments,
+        #             request_kwargs=request_kwargs,
+        #         )
+        #     )

        if len(healthy_deployments) == 0:
            exception = await async_raise_no_deployment_exception(
@@ -26,13 +26,14 @@ import litellm
 from litellm._logging import verbose_router_logger
 from litellm.caching.caching import DualCache
 from litellm.caching.redis_cache import RedisPipelineIncrementOperation
-from litellm.integrations.custom_logger import CustomLogger
+from litellm.integrations.custom_logger import CustomLogger, Span
 from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
 from litellm.litellm_core_utils.duration_parser import duration_in_seconds
 from litellm.router_strategy.tag_based_routing import _get_tags_from_request_kwargs
 from litellm.router_utils.cooldown_callbacks import (
     _get_prometheus_logger_from_callbacks,
 )
+from litellm.types.llms.openai import AllMessageValues
 from litellm.types.router import (
     DeploymentTypedDict,
     GenericBudgetConfigType,
@@ -42,13 +43,6 @@ from litellm.types.router import (
 )
 from litellm.types.utils import BudgetConfig, StandardLoggingPayload
 
-if TYPE_CHECKING:
-    from opentelemetry.trace import Span as _Span
-
-    Span = _Span
-else:
-    Span = Any
-
 DEFAULT_REDIS_SYNC_INTERVAL = 1
 
 
@@ -79,9 +73,12 @@ class RouterBudgetLimiting(CustomLogger):
 
     async def async_filter_deployments(
         self,
-        healthy_deployments: Union[List[Dict[str, Any]], Dict[str, Any]],
-        request_kwargs: Optional[Dict] = None,
-    ):
+        model: str,
+        healthy_deployments: List,
+        messages: Optional[List[AllMessageValues]],
+        request_kwargs: Optional[dict] = None,
+        parent_otel_span: Optional[Span] = None,  # type: ignore
+    ) -> List[dict]:
         """
         Filter out deployments that have exceeded their provider budget limit.
 
@@ -102,11 +99,6 @@ class RouterBudgetLimiting(CustomLogger):
 
         potential_deployments: List[Dict] = []
 
-        # Extract the parent OpenTelemetry span for tracing
-        parent_otel_span: Optional[Span] = _get_parent_otel_span_from_kwargs(
-            request_kwargs
-        )
-
         cache_keys, provider_configs, deployment_configs = (
             await self._async_get_cache_keys_for_router_budget_limiting(
                 healthy_deployments=healthy_deployments,
@@ -91,8 +91,8 @@ def _get_prometheus_logger_from_callbacks() -> Optional[PrometheusLogger]:
     for _callback in litellm._async_success_callback:
         if isinstance(_callback, PrometheusLogger):
             return _callback
-    for _callback in litellm.callbacks:
-        if isinstance(_callback, PrometheusLogger):
-            return _callback
+    for global_callback in litellm.callbacks:
+        if isinstance(global_callback, PrometheusLogger):
+            return global_callback
 
     return None
@@ -0,0 +1,99 @@
+"""
+Check if prompt caching is valid for a given deployment
+
+Route to previously cached model id, if valid
+"""
+
+from typing import List, Optional, cast
+
+from litellm import verbose_logger
+from litellm.caching.dual_cache import DualCache
+from litellm.integrations.custom_logger import CustomLogger, Span
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.utils import CallTypes, StandardLoggingPayload
+from litellm.utils import is_prompt_caching_valid_prompt
+
+from ..prompt_caching_cache import PromptCachingCache
+
+
+class PromptCachingDeploymentCheck(CustomLogger):
+    def __init__(self, cache: DualCache):
+        self.cache = cache
+
+    async def async_filter_deployments(
+        self,
+        model: str,
+        healthy_deployments: List,
+        messages: Optional[List[AllMessageValues]],
+        request_kwargs: Optional[dict] = None,
+        parent_otel_span: Optional[Span] = None,
+    ) -> List[dict]:
+        if messages is not None and is_prompt_caching_valid_prompt(
+            messages=messages,
+            model=model,
+        ):  # prompt > 1024 tokens
+            prompt_cache = PromptCachingCache(
+                cache=self.cache,
+            )
+
+            model_id_dict = await prompt_cache.async_get_model_id(
+                messages=cast(List[AllMessageValues], messages),
+                tools=None,
+            )
+            if model_id_dict is not None:
+                model_id = model_id_dict["model_id"]
+                for deployment in healthy_deployments:
+                    if deployment["model_info"]["id"] == model_id:
+                        return [deployment]
+
+        return healthy_deployments
+
+    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+        standard_logging_object: Optional[StandardLoggingPayload] = kwargs.get(
+            "standard_logging_object", None
+        )
+
+        if standard_logging_object is None:
+            return
+
+        call_type = standard_logging_object["call_type"]
+
+        if (
+            call_type != CallTypes.completion.value
+            and call_type != CallTypes.acompletion.value
+        ):  # only use prompt caching for completion calls
+            verbose_logger.debug(
+                "litellm.router_utils.pre_call_checks.prompt_caching_deployment_check: skipping adding model id to prompt caching cache, CALL TYPE IS NOT COMPLETION"
+            )
+            return
+
+        model = standard_logging_object["model"]
+        messages = standard_logging_object["messages"]
+        model_id = standard_logging_object["model_id"]
+
+        if messages is None or not isinstance(messages, list):
+            verbose_logger.debug(
+                "litellm.router_utils.pre_call_checks.prompt_caching_deployment_check: skipping adding model id to prompt caching cache, MESSAGES IS NOT A LIST"
+            )
+            return
+        if model_id is None:
+            verbose_logger.debug(
+                "litellm.router_utils.pre_call_checks.prompt_caching_deployment_check: skipping adding model id to prompt caching cache, MODEL ID IS NONE"
+            )
+            return
+
+        ## PROMPT CACHING - cache model id, if prompt caching valid prompt + provider
+        if is_prompt_caching_valid_prompt(
+            model=model,
+            messages=cast(List[AllMessageValues], messages),
+        ):
+            cache = PromptCachingCache(
+                cache=self.cache,
+            )
+            await cache.async_add_model_id(
+                model_id=model_id,
+                messages=messages,
+                tools=None,  # [TODO]: add tools once standard_logging_object supports it
+            )
+
+        return
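
`add_optional_pre_call_checks(["prompt_caching"])` registers this check on `litellm.callbacks` with the router's own cache. An equivalent manual registration, sketched under the assumption that you want to reuse `router.cache` so cached model ids are shared with the router; the model entry is a placeholder:

```python
import litellm
from litellm import Router
from litellm.router_utils.pre_call_checks.prompt_caching_deployment_check import (
    PromptCachingDeploymentCheck,
)

router = Router(
    model_list=[
        {
            "model_name": "claude-model",
            "litellm_params": {"model": "anthropic/claude-3-5-sonnet-20240620"},
        }
    ]
)

# Equivalent to passing optional_pre_call_checks=["prompt_caching"]:
# the check reuses the router's DualCache so cached model ids persist across calls.
litellm.callbacks.append(PromptCachingDeploymentCheck(cache=router.cache))
```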
@@ -172,22 +172,3 @@ class PromptCachingCache:
 
         cache_key = PromptCachingCache.get_prompt_caching_cache_key(messages, tools)
         return self.cache.get_cache(cache_key)
-
-    async def async_get_prompt_caching_deployment(
-        self,
-        router: litellm_router,
-        messages: Optional[List[AllMessageValues]],
-        tools: Optional[List[ChatCompletionToolParam]],
-    ) -> Optional[dict]:
-        model_id_dict = await self.async_get_model_id(
-            messages=messages,
-            tools=tools,
-        )
-
-        if model_id_dict is not None:
-            healthy_deployment_pydantic_obj = router.get_deployment(
-                model_id=model_id_dict["model_id"]
-            )
-            if healthy_deployment_pydantic_obj is not None:
-                return healthy_deployment_pydantic_obj.model_dump(exclude_none=True)
-        return None
@@ -667,3 +667,6 @@ class GenericBudgetWindowDetails(BaseModel):
     spend_key: str
     start_time_key: str
     ttl_seconds: int
+
+
+OptionalPreCallChecks = List[Literal["prompt_caching", "router_budget_limiting"]]
@@ -653,9 +653,9 @@ async def test_router_prompt_caching_model_stored(
 
 
 @pytest.mark.asyncio()
-@pytest.mark.skip(
-    reason="BETA FEATURE - skipping since this led to a latency impact, beta feature that is not used as yet"
-)
+# @pytest.mark.skip(
+#     reason="BETA FEATURE - skipping since this led to a latency impact, beta feature that is not used as yet"
+# )
 async def test_router_with_prompt_caching(anthropic_messages):
     """
     if prompt caching supported model called with prompt caching valid prompt,
@@ -672,15 +672,18 @@ async def test_router_with_prompt_caching(anthropic_messages):
                 "litellm_params": {
                     "model": "anthropic/claude-3-5-sonnet-20240620",
                     "api_key": os.environ.get("ANTHROPIC_API_KEY"),
+                    "mock_response": "The sky is blue.",
                 },
             },
             {
                 "model_name": "claude-model",
                 "litellm_params": {
                     "model": "anthropic.claude-3-5-sonnet-20241022-v2:0",
+                    "mock_response": "The sky is green.",
                 },
             },
-        ]
+        ],
+        optional_pre_call_checks=["prompt_caching"],
     )
 
     response = await router.acompletion(
@@ -699,6 +702,7 @@ async def test_router_with_prompt_caching(anthropic_messages):
 
     cached_model_id = cache.get_model_id(messages=anthropic_messages, tools=None)
 
+    assert cached_model_id is not None
     prompt_caching_cache_key = PromptCachingCache.get_prompt_caching_cache_key(
         messages=anthropic_messages, tools=None
     )
@@ -709,18 +713,12 @@ async def test_router_with_prompt_caching(anthropic_messages):
         {"role": "user", "content": "What is the weather in SF?"}
     ]
 
-    pc_deployment = await cache.async_get_prompt_caching_deployment(
-        router=router,
-        messages=new_messages,
-        tools=None,
-    )
-    assert pc_deployment is not None
-
-    response = await router.acompletion(
-        messages=new_messages,
-        model="claude-model",
-        mock_response="The sky is blue.",
-    )
-    print("response=", response)
-
-    assert response._hidden_params["model_id"] == initial_model_id
+    for _ in range(20):
+        response = await router.acompletion(
+            messages=new_messages,
+            model="claude-model",
+            mock_response="The sky is blue.",
+        )
+        print("response=", response)
+
+        assert response._hidden_params["model_id"] == initial_model_id
@@ -304,7 +304,7 @@ async def test_prometheus_metric_tracking():
         await asyncio.sleep(2.5)
 
         # Verify the mock was called correctly
-        mock_prometheus.track_provider_remaining_budget.assert_called_once()
+        mock_prometheus.track_provider_remaining_budget.assert_called()
 
 
 @pytest.mark.asyncio
@@ -1058,3 +1058,28 @@ def test_has_default_fallbacks(model_list, has_default_fallbacks, expected_resul
         ),
     )
     assert router._has_default_fallbacks() is expected_result
+
+
+def test_add_optional_pre_call_checks(model_list):
+    router = Router(model_list=model_list)
+
+    router.add_optional_pre_call_checks(["prompt_caching"])
+    assert len(litellm.callbacks) > 0
+
+
+@pytest.mark.asyncio
+async def test_async_callback_filter_deployments(model_list):
+    from litellm.router_strategy.budget_limiter import RouterBudgetLimiting
+
+    router = Router(model_list=model_list)
+
+    healthy_deployments = router.get_model_list(model_name="gpt-3.5-turbo")
+
+    new_healthy_deployments = await router.async_callback_filter_deployments(
+        model="gpt-3.5-turbo",
+        healthy_deployments=healthy_deployments,
+        messages=[],
+        parent_otel_span=None,
+    )
+
+    assert len(new_healthy_deployments) == len(healthy_deployments)