Litellm dev readd prompt caching (#7299)

* fix(router.py): re-add saving model id on prompt caching valid successful deployment

* fix(router.py): introduce optional pre_call_checks

isolate prompt caching logic in a separate file

* fix(prompt_caching_deployment_check.py): fix import

* fix(router.py): new 'async_filter_deployments' event hook

allows custom logger to filter deployments returned to routing strategy

* feat(prompt_caching_deployment_check.py): initial working commit of prompt caching based routing

* fix(cooldown_callbacks.py): fix linting error

* fix(budget_limiter.py): move budget logger to async_filter_deployment hook

* test: add unit test

* test(test_router_helper_utils.py): add unit testing

* fix(budget_limiter.py): fix linting errors

* docs(config_settings.md): add 'optional_pre_call_checks' to router_settings param docs
This commit is contained in:
Krish Dholakia 2024-12-18 15:13:49 -08:00 committed by GitHub
parent d214d3cc3f
commit 2f08341a08
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 276 additions and 74 deletions

View file

@ -285,6 +285,8 @@ router_settings:
| redis_url | str | URL for Redis server. **Known performance issue with Redis URL.** |
| cache_responses | boolean | Flag to enable caching LLM Responses, if cache set under `router_settings`. If true, caches responses. Defaults to False. |
| router_general_settings | RouterGeneralSettings | [SDK-Only] Router general settings - contains optimizations like 'async_only_mode'. [Docs](../routing.md#router-general-settings) |
| optional_pre_call_checks | List[str] | List of pre-call checks to add to the router. Currently supported: 'router_budget_limiting', 'prompt_caching' |
### environment variables - Reference

View file

@ -34,6 +34,7 @@ from litellm.proxy._types import (
LiteLLM_UpperboundKeyGenerateParams,
)
from litellm.types.utils import StandardKeyGenerationConfig, LlmProviders
from litellm.integrations.custom_logger import CustomLogger
import httpx
import dotenv
from enum import Enum
@ -75,7 +76,9 @@ logged_real_time_event_types: Optional[Union[List[str], Literal["*"]]] = None
_known_custom_logger_compatible_callbacks: List = list(
get_args(_custom_logger_compatible_callbacks_literal)
)
callbacks: List[Union[Callable, _custom_logger_compatible_callbacks_literal]] = []
callbacks: List[
Union[Callable, _custom_logger_compatible_callbacks_literal, CustomLogger]
] = []
langfuse_default_tags: Optional[List[str]] = None
langsmith_batch_size: Optional[int] = None
argilla_batch_size: Optional[int] = None

View file

@ -3,7 +3,7 @@
import os
import traceback
from datetime import datetime as datetimeObj
from typing import TYPE_CHECKING, Any, Literal, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, List, Literal, Optional, Tuple, Union
import dotenv
from pydantic import BaseModel
@ -11,7 +11,7 @@ from pydantic import BaseModel
from litellm.caching.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.types.integrations.argilla import ArgillaItem
from litellm.types.llms.openai import ChatCompletionRequest
from litellm.types.llms.openai import AllMessageValues, ChatCompletionRequest
from litellm.types.services import ServiceLoggerPayload
from litellm.types.utils import (
AdapterCompletionStreamWrapper,
@ -69,6 +69,16 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
Allows usage-based-routing-v2 to run pre-call rpm checks within the picked deployment's semaphore (concurrency-safe tpm/rpm checks).
"""
async def async_filter_deployments(
self,
model: str,
healthy_deployments: List,
messages: Optional[List[AllMessageValues]],
request_kwargs: Optional[dict] = None,
parent_otel_span: Optional[Span] = None,
) -> List[dict]:
return healthy_deployments
async def async_pre_call_check(
self, deployment: dict, parent_otel_span: Optional[Span]
) -> Optional[dict]:

View file

@ -92,6 +92,9 @@ from litellm.router_utils.handle_error import (
async_raise_no_deployment_exception,
send_llm_exception_alert,
)
from litellm.router_utils.pre_call_checks.prompt_caching_deployment_check import (
PromptCachingDeploymentCheck,
)
from litellm.router_utils.router_callbacks.track_deployment_metrics import (
increment_deployment_failures_for_current_minute,
increment_deployment_successes_for_current_minute,
@ -128,6 +131,7 @@ from litellm.types.router import (
LiteLLMParamsTypedDict,
ModelGroupInfo,
ModelInfo,
OptionalPreCallChecks,
RetryPolicy,
RouterCacheEnum,
RouterErrors,
@ -153,6 +157,7 @@ from litellm.utils import (
get_llm_provider,
get_secret,
get_utc_datetime,
is_prompt_caching_valid_prompt,
is_region_allowed,
)
@ -248,6 +253,7 @@ class Router:
"cost-based-routing",
"usage-based-routing-v2",
] = "simple-shuffle",
optional_pre_call_checks: Optional[OptionalPreCallChecks] = None,
routing_strategy_args: dict = {}, # just for latency-based
provider_budget_config: Optional[GenericBudgetConfigType] = None,
alerting_config: Optional[AlertingConfig] = None,
@ -542,11 +548,10 @@ class Router:
if RouterBudgetLimiting.should_init_router_budget_limiter(
model_list=model_list, provider_budget_config=self.provider_budget_config
):
self.router_budget_logger = RouterBudgetLimiting(
router_cache=self.cache,
provider_budget_config=self.provider_budget_config,
model_list=self.model_list,
)
if optional_pre_call_checks is not None:
optional_pre_call_checks.append("router_budget_limiting")
else:
optional_pre_call_checks = ["router_budget_limiting"]
self.retry_policy: Optional[RetryPolicy] = None
if retry_policy is not None:
if isinstance(retry_policy, dict):
@ -577,6 +582,10 @@ class Router:
)
self.alerting_config: Optional[AlertingConfig] = alerting_config
if optional_pre_call_checks is not None:
self.add_optional_pre_call_checks(optional_pre_call_checks)
if self.alerting_config is not None:
self._initialize_alerting()
@ -612,6 +621,23 @@ class Router:
f"Dictionary '{fallback_dict}' must have exactly one key, but has {len(fallback_dict)} keys."
)
def add_optional_pre_call_checks(
self, optional_pre_call_checks: Optional[OptionalPreCallChecks]
):
if optional_pre_call_checks is not None:
for pre_call_check in optional_pre_call_checks:
_callback: Optional[CustomLogger] = None
if pre_call_check == "prompt_caching":
_callback = PromptCachingDeploymentCheck(cache=self.cache)
elif pre_call_check == "router_budget_limiting":
_callback = RouterBudgetLimiting(
router_cache=self.cache,
provider_budget_config=self.provider_budget_config,
model_list=self.model_list,
)
if _callback is not None:
litellm.callbacks.append(_callback)
def routing_strategy_init(
self, routing_strategy: Union[RoutingStrategy, str], routing_strategy_args: dict
):
@ -1810,7 +1836,6 @@ class Router:
return await litellm._arealtime(**{**data, "caching": self.cache_responses, **kwargs}) # type: ignore
except Exception as e:
traceback.print_exc()
if self.num_retries > 0:
kwargs["model"] = model
kwargs["messages"] = messages
@ -3261,6 +3286,7 @@ class Router:
litellm_router_instance=self,
deployment_id=id,
)
return tpm_key
except Exception as e:
@ -3699,6 +3725,57 @@ class Router:
).start() # log response
raise e
async def async_callback_filter_deployments(
self,
model: str,
healthy_deployments: List[dict],
messages: Optional[List[AllMessageValues]],
parent_otel_span: Optional[Span],
request_kwargs: Optional[dict] = None,
logging_obj: Optional[LiteLLMLogging] = None,
):
"""
For usage-based-routing-v2, enables running rpm checks before the call is made, inside the semaphore.
-> makes the calls concurrency-safe, when rpm limits are set for a deployment
Returns:
- None
Raises:
- Rate Limit Exception - If the deployment is over it's tpm/rpm limits
"""
returned_healthy_deployments = healthy_deployments
for _callback in litellm.callbacks:
if isinstance(_callback, CustomLogger):
try:
returned_healthy_deployments = (
await _callback.async_filter_deployments(
model=model,
healthy_deployments=returned_healthy_deployments,
messages=messages,
request_kwargs=request_kwargs,
parent_otel_span=parent_otel_span,
)
)
except Exception as e:
## LOG FAILURE EVENT
if logging_obj is not None:
asyncio.create_task(
logging_obj.async_failure_handler(
exception=e,
traceback_exception=traceback.format_exc(),
end_time=time.time(),
)
)
## LOGGING
threading.Thread(
target=logging_obj.failure_handler,
args=(e, traceback.format_exc()),
).start() # log response
raise e
return returned_healthy_deployments
def _generate_model_id(self, model_group: str, litellm_params: dict):
"""
Helper function to consistently generate the same id for a deployment
@ -5188,10 +5265,22 @@ class Router:
cooldown_deployments=cooldown_deployments,
)
healthy_deployments = await self.async_callback_filter_deployments(
model=model,
healthy_deployments=healthy_deployments,
messages=(
cast(List[AllMessageValues], messages)
if messages is not None
else None
),
request_kwargs=request_kwargs,
parent_otel_span=parent_otel_span,
)
if self.enable_pre_call_checks and messages is not None:
healthy_deployments = self._pre_call_checks(
model=model,
healthy_deployments=healthy_deployments,
healthy_deployments=cast(List[Dict], healthy_deployments),
messages=messages,
request_kwargs=request_kwargs,
)
@ -5203,13 +5292,13 @@ class Router:
healthy_deployments=healthy_deployments,
)
if self.router_budget_logger:
healthy_deployments = (
await self.router_budget_logger.async_filter_deployments(
healthy_deployments=healthy_deployments,
request_kwargs=request_kwargs,
)
)
# if self.router_budget_logger:
# healthy_deployments = (
# await self.router_budget_logger.async_filter_deployments(
# healthy_deployments=healthy_deployments,
# request_kwargs=request_kwargs,
# )
# )
if len(healthy_deployments) == 0:
exception = await async_raise_no_deployment_exception(

View file

@ -26,13 +26,14 @@ import litellm
from litellm._logging import verbose_router_logger
from litellm.caching.caching import DualCache
from litellm.caching.redis_cache import RedisPipelineIncrementOperation
from litellm.integrations.custom_logger import CustomLogger
from litellm.integrations.custom_logger import CustomLogger, Span
from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
from litellm.litellm_core_utils.duration_parser import duration_in_seconds
from litellm.router_strategy.tag_based_routing import _get_tags_from_request_kwargs
from litellm.router_utils.cooldown_callbacks import (
_get_prometheus_logger_from_callbacks,
)
from litellm.types.llms.openai import AllMessageValues
from litellm.types.router import (
DeploymentTypedDict,
GenericBudgetConfigType,
@ -42,13 +43,6 @@ from litellm.types.router import (
)
from litellm.types.utils import BudgetConfig, StandardLoggingPayload
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
Span = _Span
else:
Span = Any
DEFAULT_REDIS_SYNC_INTERVAL = 1
@ -79,9 +73,12 @@ class RouterBudgetLimiting(CustomLogger):
async def async_filter_deployments(
self,
healthy_deployments: Union[List[Dict[str, Any]], Dict[str, Any]],
request_kwargs: Optional[Dict] = None,
):
model: str,
healthy_deployments: List,
messages: Optional[List[AllMessageValues]],
request_kwargs: Optional[dict] = None,
parent_otel_span: Optional[Span] = None, # type: ignore
) -> List[dict]:
"""
Filter out deployments that have exceeded their provider budget limit.
@ -102,11 +99,6 @@ class RouterBudgetLimiting(CustomLogger):
potential_deployments: List[Dict] = []
# Extract the parent OpenTelemetry span for tracing
parent_otel_span: Optional[Span] = _get_parent_otel_span_from_kwargs(
request_kwargs
)
cache_keys, provider_configs, deployment_configs = (
await self._async_get_cache_keys_for_router_budget_limiting(
healthy_deployments=healthy_deployments,

View file

@ -91,8 +91,8 @@ def _get_prometheus_logger_from_callbacks() -> Optional[PrometheusLogger]:
for _callback in litellm._async_success_callback:
if isinstance(_callback, PrometheusLogger):
return _callback
for _callback in litellm.callbacks:
if isinstance(_callback, PrometheusLogger):
return _callback
for global_callback in litellm.callbacks:
if isinstance(global_callback, PrometheusLogger):
return global_callback
return None

View file

@ -0,0 +1,99 @@
"""
Check if prompt caching is valid for a given deployment
Route to previously cached model id, if valid
"""
from typing import List, Optional, cast
from litellm import verbose_logger
from litellm.caching.dual_cache import DualCache
from litellm.integrations.custom_logger import CustomLogger, Span
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import CallTypes, StandardLoggingPayload
from litellm.utils import is_prompt_caching_valid_prompt
from ..prompt_caching_cache import PromptCachingCache
class PromptCachingDeploymentCheck(CustomLogger):
def __init__(self, cache: DualCache):
self.cache = cache
async def async_filter_deployments(
self,
model: str,
healthy_deployments: List,
messages: Optional[List[AllMessageValues]],
request_kwargs: Optional[dict] = None,
parent_otel_span: Optional[Span] = None,
) -> List[dict]:
if messages is not None and is_prompt_caching_valid_prompt(
messages=messages,
model=model,
): # prompt > 1024 tokens
prompt_cache = PromptCachingCache(
cache=self.cache,
)
model_id_dict = await prompt_cache.async_get_model_id(
messages=cast(List[AllMessageValues], messages),
tools=None,
)
if model_id_dict is not None:
model_id = model_id_dict["model_id"]
for deployment in healthy_deployments:
if deployment["model_info"]["id"] == model_id:
return [deployment]
return healthy_deployments
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
standard_logging_object: Optional[StandardLoggingPayload] = kwargs.get(
"standard_logging_object", None
)
if standard_logging_object is None:
return
call_type = standard_logging_object["call_type"]
if (
call_type != CallTypes.completion.value
and call_type != CallTypes.acompletion.value
): # only use prompt caching for completion calls
verbose_logger.debug(
"litellm.router_utils.pre_call_checks.prompt_caching_deployment_check: skipping adding model id to prompt caching cache, CALL TYPE IS NOT COMPLETION"
)
return
model = standard_logging_object["model"]
messages = standard_logging_object["messages"]
model_id = standard_logging_object["model_id"]
if messages is None or not isinstance(messages, list):
verbose_logger.debug(
"litellm.router_utils.pre_call_checks.prompt_caching_deployment_check: skipping adding model id to prompt caching cache, MESSAGES IS NOT A LIST"
)
return
if model_id is None:
verbose_logger.debug(
"litellm.router_utils.pre_call_checks.prompt_caching_deployment_check: skipping adding model id to prompt caching cache, MODEL ID IS NONE"
)
return
## PROMPT CACHING - cache model id, if prompt caching valid prompt + provider
if is_prompt_caching_valid_prompt(
model=model,
messages=cast(List[AllMessageValues], messages),
):
cache = PromptCachingCache(
cache=self.cache,
)
await cache.async_add_model_id(
model_id=model_id,
messages=messages,
tools=None, # [TODO]: add tools once standard_logging_object supports it
)
return

View file

@ -172,22 +172,3 @@ class PromptCachingCache:
cache_key = PromptCachingCache.get_prompt_caching_cache_key(messages, tools)
return self.cache.get_cache(cache_key)
async def async_get_prompt_caching_deployment(
self,
router: litellm_router,
messages: Optional[List[AllMessageValues]],
tools: Optional[List[ChatCompletionToolParam]],
) -> Optional[dict]:
model_id_dict = await self.async_get_model_id(
messages=messages,
tools=tools,
)
if model_id_dict is not None:
healthy_deployment_pydantic_obj = router.get_deployment(
model_id=model_id_dict["model_id"]
)
if healthy_deployment_pydantic_obj is not None:
return healthy_deployment_pydantic_obj.model_dump(exclude_none=True)
return None

View file

@ -667,3 +667,6 @@ class GenericBudgetWindowDetails(BaseModel):
spend_key: str
start_time_key: str
ttl_seconds: int
OptionalPreCallChecks = List[Literal["prompt_caching", "router_budget_limiting"]]

View file

@ -653,9 +653,9 @@ async def test_router_prompt_caching_model_stored(
@pytest.mark.asyncio()
@pytest.mark.skip(
reason="BETA FEATURE - skipping since this led to a latency impact, beta feature that is not used as yet"
)
# @pytest.mark.skip(
# reason="BETA FEATURE - skipping since this led to a latency impact, beta feature that is not used as yet"
# )
async def test_router_with_prompt_caching(anthropic_messages):
"""
if prompt caching supported model called with prompt caching valid prompt,
@ -672,15 +672,18 @@ async def test_router_with_prompt_caching(anthropic_messages):
"litellm_params": {
"model": "anthropic/claude-3-5-sonnet-20240620",
"api_key": os.environ.get("ANTHROPIC_API_KEY"),
"mock_response": "The sky is blue.",
},
},
{
"model_name": "claude-model",
"litellm_params": {
"model": "anthropic.claude-3-5-sonnet-20241022-v2:0",
"mock_response": "The sky is green.",
},
},
]
],
optional_pre_call_checks=["prompt_caching"],
)
response = await router.acompletion(
@ -699,6 +702,7 @@ async def test_router_with_prompt_caching(anthropic_messages):
cached_model_id = cache.get_model_id(messages=anthropic_messages, tools=None)
assert cached_model_id is not None
prompt_caching_cache_key = PromptCachingCache.get_prompt_caching_cache_key(
messages=anthropic_messages, tools=None
)
@ -709,13 +713,7 @@ async def test_router_with_prompt_caching(anthropic_messages):
{"role": "user", "content": "What is the weather in SF?"}
]
pc_deployment = await cache.async_get_prompt_caching_deployment(
router=router,
messages=new_messages,
tools=None,
)
assert pc_deployment is not None
for _ in range(20):
response = await router.acompletion(
messages=new_messages,
model="claude-model",

View file

@ -304,7 +304,7 @@ async def test_prometheus_metric_tracking():
await asyncio.sleep(2.5)
# Verify the mock was called correctly
mock_prometheus.track_provider_remaining_budget.assert_called_once()
mock_prometheus.track_provider_remaining_budget.assert_called()
@pytest.mark.asyncio

View file

@ -1058,3 +1058,28 @@ def test_has_default_fallbacks(model_list, has_default_fallbacks, expected_resul
),
)
assert router._has_default_fallbacks() is expected_result
def test_add_optional_pre_call_checks(model_list):
router = Router(model_list=model_list)
router.add_optional_pre_call_checks(["prompt_caching"])
assert len(litellm.callbacks) > 0
@pytest.mark.asyncio
async def test_async_callback_filter_deployments(model_list):
from litellm.router_strategy.budget_limiter import RouterBudgetLimiting
router = Router(model_list=model_list)
healthy_deployments = router.get_model_list(model_name="gpt-3.5-turbo")
new_healthy_deployments = await router.async_callback_filter_deployments(
model="gpt-3.5-turbo",
healthy_deployments=healthy_deployments,
messages=[],
parent_otel_span=None,
)
assert len(new_healthy_deployments) == len(healthy_deployments)