redis otel tracing + async support for latency routing (#6452)

* docs(exception_mapping.md): add missing exception types

Fixes https://github.com/Aider-AI/aider/issues/2120#issuecomment-2438971183

* fix(main.py): register custom model pricing with specific key

Ensure custom model pricing is registered under the specific model+provider key combination (see the pricing sketch after this list)

* test: make testing more robust for custom pricing

* fix(redis_cache.py): instrument otel logging for sync redis calls

ensures complete coverage for all redis cache calls

* refactor: pass parent_otel_span for redis caching calls in router

allows for more observability into which calls are causing latency issues

* test: update tests with new params

* refactor: ensure e2e otel tracing for router

* refactor(router.py): add more otel tracing across router

catch all latency issues for router requests

* fix: fix linting error

* fix(router.py): fix linting error

* fix: fix test

* test: fix tests

* fix(dual_cache.py): pass ttl to redis cache

* fix: fix param
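
For the custom model pricing fix above, a minimal sketch of the intended behavior, assuming the public `litellm.register_model` API and a hypothetical Azure deployment name; this is illustrative, not the code changed in `main.py`:

```python
import litellm

# Illustrative only: pricing is registered under the provider+model key, so cost
# lookups for "azure/my-gpt4-deployment" resolve to this entry instead of a bare model name.
litellm.register_model(
    {
        "azure/my-gpt4-deployment": {      # hypothetical deployment name
            "input_cost_per_token": 0.00001,
            "output_cost_per_token": 0.00003,
            "litellm_provider": "azure",
        }
    }
)
```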
Krish Dholakia 2024-10-28 21:52:12 -07:00 committed by GitHub
parent d9e7818e6b
commit 4f8a3fd4cf
25 changed files with 559 additions and 147 deletions

View file

@ -44,8 +44,7 @@ class ServiceLogging(CustomLogger):
"""
Handles both sync and async monitoring by checking for existing event loop.
"""
# if service == ServiceTypes.REDIS:
# print(f"SYNC service: {service}, call_type: {call_type}")
if self.mock_testing:
self.mock_testing_sync_success_hook += 1
@ -112,8 +111,7 @@ class ServiceLogging(CustomLogger):
"""
- For counting if the redis, postgres call is successful
"""
# if service == ServiceTypes.REDIS:
# print(f"service: {service}, call_type: {call_type}")
if self.mock_testing:
self.mock_testing_async_success_hook += 1

View file

@ -8,7 +8,14 @@ Has 4 methods:
- async_get_cache
"""
from typing import Optional
from typing import TYPE_CHECKING, Any, Optional
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
Span = _Span
else:
Span = Any
class BaseCache:
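
The `TYPE_CHECKING` block above is a pattern repeated in most files in this diff: the OpenTelemetry `Span` type is imported only for type checkers, so `opentelemetry` stays an optional runtime dependency. A small self-contained sketch of the same pattern:

```python
from typing import TYPE_CHECKING, Any, Optional

if TYPE_CHECKING:
    # only evaluated by type checkers; never imported at runtime
    from opentelemetry.trace import Span as _Span

    Span = _Span
else:
    Span = Any


def get_cache(key: str, parent_otel_span: Optional[Span] = None):
    """Annotation resolves for mypy, with no opentelemetry import needed at runtime."""
    ...
```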

View file

@ -1,10 +1,17 @@
import json
from typing import Optional
from typing import TYPE_CHECKING, Any, Optional
from litellm._logging import print_verbose
from .base_cache import BaseCache
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
Span = _Span
else:
Span = Any
class DiskCache(BaseCache):
def __init__(self, disk_cache_dir: Optional[str] = None):

View file

@ -9,7 +9,7 @@ Has 4 primary methods:
"""
import traceback
from typing import List, Optional
from typing import TYPE_CHECKING, Any, List, Optional
import litellm
from litellm._logging import print_verbose, verbose_logger
@ -18,6 +18,13 @@ from .base_cache import BaseCache
from .in_memory_cache import InMemoryCache
from .redis_cache import RedisCache
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
Span = _Span
else:
Span = Any
class DualCache(BaseCache):
"""
@ -90,7 +97,13 @@ class DualCache(BaseCache):
verbose_logger.error(f"LiteLLM Cache: Excepton async add_cache: {str(e)}")
raise e
def get_cache(self, key, local_only: bool = False, **kwargs):
def get_cache(
self,
key,
parent_otel_span: Optional[Span] = None,
local_only: bool = False,
**kwargs,
):
# Try to fetch from in-memory cache first
try:
result = None
@ -102,7 +115,9 @@ class DualCache(BaseCache):
if result is None and self.redis_cache is not None and local_only is False:
# If not found in in-memory cache, try fetching from Redis
redis_result = self.redis_cache.get_cache(key, **kwargs)
redis_result = self.redis_cache.get_cache(
key, parent_otel_span=parent_otel_span
)
if redis_result is not None:
# Update in-memory cache with the value from Redis
@ -115,7 +130,13 @@ class DualCache(BaseCache):
except Exception:
verbose_logger.error(traceback.format_exc())
def batch_get_cache(self, keys: list, local_only: bool = False, **kwargs):
def batch_get_cache(
self,
keys: list,
parent_otel_span: Optional[Span],
local_only: bool = False,
**kwargs,
):
try:
result = [None for _ in range(len(keys))]
if self.in_memory_cache is not None:
@ -133,7 +154,9 @@ class DualCache(BaseCache):
key for key, value in zip(keys, result) if value is None
]
# If not found in in-memory cache, try fetching from Redis
redis_result = self.redis_cache.batch_get_cache(sublist_keys, **kwargs)
redis_result = self.redis_cache.batch_get_cache(
sublist_keys, parent_otel_span=parent_otel_span
)
if redis_result is not None:
# Update in-memory cache with the value from Redis
for key in redis_result:
@ -147,7 +170,13 @@ class DualCache(BaseCache):
except Exception:
verbose_logger.error(traceback.format_exc())
async def async_get_cache(self, key, local_only: bool = False, **kwargs):
async def async_get_cache(
self,
key,
parent_otel_span: Optional[Span] = None,
local_only: bool = False,
**kwargs,
):
# Try to fetch from in-memory cache first
try:
print_verbose(
@ -165,7 +194,9 @@ class DualCache(BaseCache):
if result is None and self.redis_cache is not None and local_only is False:
# If not found in in-memory cache, try fetching from Redis
redis_result = await self.redis_cache.async_get_cache(key, **kwargs)
redis_result = await self.redis_cache.async_get_cache(
key, parent_otel_span=parent_otel_span
)
if redis_result is not None:
# Update in-memory cache with the value from Redis
@ -181,7 +212,11 @@ class DualCache(BaseCache):
verbose_logger.error(traceback.format_exc())
async def async_batch_get_cache(
self, keys: list, local_only: bool = False, **kwargs
self,
keys: list,
parent_otel_span: Optional[Span] = None,
local_only: bool = False,
**kwargs,
):
try:
result = [None for _ in range(len(keys))]
@ -202,7 +237,7 @@ class DualCache(BaseCache):
]
# If not found in in-memory cache, try fetching from Redis
redis_result = await self.redis_cache.async_batch_get_cache(
sublist_keys, **kwargs
sublist_keys, parent_otel_span=parent_otel_span
)
if redis_result is not None:
@ -260,7 +295,12 @@ class DualCache(BaseCache):
)
async def async_increment_cache(
self, key, value: float, local_only: bool = False, **kwargs
self,
key,
value: float,
parent_otel_span: Optional[Span],
local_only: bool = False,
**kwargs,
) -> float:
"""
Key - the key in cache
@ -277,7 +317,12 @@ class DualCache(BaseCache):
)
if self.redis_cache is not None and local_only is False:
result = await self.redis_cache.async_increment(key, value, **kwargs)
result = await self.redis_cache.async_increment(
key,
value,
parent_otel_span=parent_otel_span,
ttl=kwargs.get("ttl", None),
)
return result
except Exception as e:
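
With the signatures above, callers thread the request's parent OTEL span into the dual cache so the Redis round-trips it triggers are attributed to that request. A minimal sketch, assuming the span was stashed in the request kwargs upstream (as the router does via `_get_parent_otel_span_from_kwargs`); the helper itself is illustrative, not part of this diff:

```python
from litellm.caching.caching import DualCache
from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs


async def lookup_with_tracing(dual_cache: DualCache, key: str, request_kwargs: dict):
    # pull the request's span (if any) back out of the kwargs it was attached to
    parent_otel_span = _get_parent_otel_span_from_kwargs(request_kwargs)
    # any Redis fallback inside the dual cache is now logged against this span
    return await dual_cache.async_get_cache(key=key, parent_otel_span=parent_otel_span)
```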

View file

@ -13,6 +13,7 @@ import asyncio
import inspect
import json
import time
import traceback
from datetime import timedelta
from typing import TYPE_CHECKING, Any, List, Optional, Tuple
@ -25,14 +26,17 @@ from litellm.types.utils import all_litellm_params
from .base_cache import BaseCache
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
from redis.asyncio import Redis
from redis.asyncio.client import Pipeline
pipeline = Pipeline
async_redis_client = Redis
Span = _Span
else:
pipeline = Any
async_redis_client = Any
Span = Any
class RedisCache(BaseCache):
@ -524,7 +528,11 @@ class RedisCache(BaseCache):
await self.flush_cache_buffer() # logging done in here
async def async_increment(
self, key, value: float, ttl: Optional[int] = None, **kwargs
self,
key,
value: float,
ttl: Optional[int] = None,
parent_otel_span: Optional[Span] = None,
) -> float:
from redis.asyncio import Redis
@ -552,7 +560,7 @@ class RedisCache(BaseCache):
call_type="async_increment",
start_time=start_time,
end_time=end_time,
parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
parent_otel_span=parent_otel_span,
)
)
return result
@ -568,7 +576,7 @@ class RedisCache(BaseCache):
call_type="async_increment",
start_time=start_time,
end_time=end_time,
parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
parent_otel_span=parent_otel_span,
)
)
verbose_logger.error(
@ -601,7 +609,7 @@ class RedisCache(BaseCache):
cached_response = ast.literal_eval(cached_response)
return cached_response
def get_cache(self, key, **kwargs):
def get_cache(self, key, parent_otel_span: Optional[Span] = None, **kwargs):
try:
key = self.check_and_fix_namespace(key=key)
print_verbose(f"Get Redis Cache: key: {key}")
@ -615,6 +623,7 @@ class RedisCache(BaseCache):
call_type="get_cache",
start_time=start_time,
end_time=end_time,
parent_otel_span=parent_otel_span,
)
print_verbose(
f"Got Redis Cache: key: {key}, cached_response {cached_response}"
@ -626,11 +635,12 @@ class RedisCache(BaseCache):
"litellm.caching.caching: get() - Got exception from REDIS: ", e
)
def batch_get_cache(self, key_list) -> dict:
def batch_get_cache(self, key_list, parent_otel_span: Optional[Span]) -> dict:
"""
Use Redis for bulk read operations
"""
key_value_dict = {}
try:
_keys = []
for cache_key in key_list:
@ -646,6 +656,7 @@ class RedisCache(BaseCache):
call_type="batch_get_cache",
start_time=start_time,
end_time=end_time,
parent_otel_span=parent_otel_span,
)
# Associate the results back with their keys.
@ -662,7 +673,9 @@ class RedisCache(BaseCache):
print_verbose(f"Error occurred in pipeline read - {str(e)}")
return key_value_dict
async def async_get_cache(self, key, **kwargs):
async def async_get_cache(
self, key, parent_otel_span: Optional[Span] = None, **kwargs
):
from redis.asyncio import Redis
_redis_client: Redis = self.init_async_client() # type: ignore
@ -686,7 +699,7 @@ class RedisCache(BaseCache):
call_type="async_get_cache",
start_time=start_time,
end_time=end_time,
parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
parent_otel_span=parent_otel_span,
event_metadata={"key": key},
)
)
@ -703,7 +716,7 @@ class RedisCache(BaseCache):
call_type="async_get_cache",
start_time=start_time,
end_time=end_time,
parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
parent_otel_span=parent_otel_span,
event_metadata={"key": key},
)
)
@ -712,10 +725,13 @@ class RedisCache(BaseCache):
f"litellm.caching.caching: async get() - Got exception from REDIS: {str(e)}"
)
async def async_batch_get_cache(self, key_list) -> dict:
async def async_batch_get_cache(
self, key_list: List[str], parent_otel_span: Optional[Span] = None
) -> dict:
"""
Use Redis for bulk read operations
"""
_redis_client = await self.init_async_client()
key_value_dict = {}
start_time = time.time()
@ -737,6 +753,7 @@ class RedisCache(BaseCache):
call_type="async_batch_get_cache",
start_time=start_time,
end_time=end_time,
parent_otel_span=parent_otel_span,
)
)
@ -764,6 +781,7 @@ class RedisCache(BaseCache):
call_type="async_batch_get_cache",
start_time=start_time,
end_time=end_time,
parent_otel_span=parent_otel_span,
)
)
print_verbose(f"Error occurred in pipeline read - {str(e)}")
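
`async_increment` now takes explicit arguments instead of reading the span out of `**kwargs`, and `DualCache.async_increment_cache` forwards the caller's `ttl` along with it (the `fix(dual_cache.py): pass ttl to redis cache` item above). A hedged sketch of a call site, with a made-up key format:

```python
from litellm.caching.caching import RedisCache


async def bump_minute_usage(
    cache: RedisCache, deployment_id: str, tokens: float, parent_otel_span=None
) -> float:
    # hypothetical helper; the key layout is illustrative, not the router's exact key
    return await cache.async_increment(
        key=f"{deployment_id}:tpm",
        value=tokens,
        ttl=60,                             # expire with the per-minute routing window
        parent_otel_span=parent_otel_span,  # Redis latency is logged against this span
    )
```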

View file

@ -268,6 +268,7 @@ class SlackAlerting(CustomBatchLogger):
SlackAlertingCacheKeys.failed_requests_key.value,
),
value=1,
parent_otel_span=None, # no attached request, this is a background operation
)
return_val += 1
@ -279,6 +280,7 @@ class SlackAlerting(CustomBatchLogger):
deployment_metrics.id, SlackAlertingCacheKeys.latency_key.value
),
value=deployment_metrics.latency_per_output_token,
parent_otel_span=None, # no attached request, this is a background operation
)
return_val += 1
@ -1518,7 +1520,8 @@ Model Info:
report_sent_bool = False
report_sent = await self.internal_usage_cache.async_get_cache(
key=SlackAlertingCacheKeys.report_sent_key.value
key=SlackAlertingCacheKeys.report_sent_key.value,
parent_otel_span=None,
) # None | float
current_time = time.time()

View file

@ -3,7 +3,7 @@
import os
import traceback
from datetime import datetime as datetimeObj
from typing import Any, Literal, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Literal, Optional, Tuple, Union
import dotenv
from pydantic import BaseModel
@ -21,6 +21,13 @@ from litellm.types.utils import (
StandardLoggingPayload,
)
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
Span = _Span
else:
Span = Any
class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
# Class variables or attributes
@ -62,7 +69,9 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
Allows usage-based-routing-v2 to run pre-call rpm checks within the picked deployment's semaphore (concurrency-safe tpm/rpm checks).
"""
async def async_pre_call_check(self, deployment: dict) -> Optional[dict]:
async def async_pre_call_check(
self, deployment: dict, parent_otel_span: Optional[Span]
) -> Optional[dict]:
pass
def pre_call_check(self, deployment: dict) -> Optional[dict]:

View file

@ -8,7 +8,12 @@ import litellm
from litellm._logging import verbose_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.types.services import ServiceLoggerPayload
from litellm.types.utils import StandardLoggingPayload
from litellm.types.utils import (
EmbeddingResponse,
ImageResponse,
ModelResponse,
StandardLoggingPayload,
)
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
@ -136,12 +141,12 @@ class OpenTelemetry(CustomLogger):
_end_time_ns = 0
if isinstance(start_time, float):
_start_time_ns = int(int(start_time) * 1e9)
_start_time_ns = int(start_time * 1e9)
else:
_start_time_ns = self._to_ns(start_time)
if isinstance(end_time, float):
_end_time_ns = int(int(end_time) * 1e9)
_end_time_ns = int(end_time * 1e9)
else:
_end_time_ns = self._to_ns(end_time)
@ -276,6 +281,21 @@ class OpenTelemetry(CustomLogger):
# End Parent OTEL Sspan
parent_otel_span.end(end_time=self._to_ns(datetime.now()))
async def async_post_call_success_hook(
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
response: Union[Any, ModelResponse, EmbeddingResponse, ImageResponse],
):
from opentelemetry import trace
from opentelemetry.trace import Status, StatusCode
parent_otel_span = user_api_key_dict.parent_otel_span
if parent_otel_span is not None:
parent_otel_span.set_status(Status(StatusCode.OK))
# End Parent OTEL Sspan
parent_otel_span.end(end_time=self._to_ns(datetime.now()))
def _handle_sucess(self, kwargs, response_obj, start_time, end_time):
from opentelemetry import trace
from opentelemetry.trace import Status, StatusCode
@ -314,8 +334,8 @@ class OpenTelemetry(CustomLogger):
span.end(end_time=self._to_ns(end_time))
if parent_otel_span is not None:
parent_otel_span.end(end_time=self._to_ns(datetime.now()))
# if parent_otel_span is not None:
# parent_otel_span.end(end_time=self._to_ns(datetime.now()))
def _handle_failure(self, kwargs, response_obj, start_time, end_time):
from opentelemetry.trace import Status, StatusCode
@ -808,12 +828,12 @@ class OpenTelemetry(CustomLogger):
end_time = logging_payload.end_time
if isinstance(start_time, float):
_start_time_ns = int(int(start_time) * 1e9)
_start_time_ns = int(start_time * 1e9)
else:
_start_time_ns = self._to_ns(start_time)
if isinstance(end_time, float):
_end_time_ns = int(int(end_time) * 1e9)
_end_time_ns = int(end_time * 1e9)
else:
_end_time_ns = self._to_ns(end_time)
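
The timestamp change above, from `int(int(start_time) * 1e9)` to `int(start_time * 1e9)`, matters for latency spans: casting to `int` first truncated the fractional seconds, so short spans collapsed onto whole-second boundaries. A quick worked example:

```python
t = 1730170332.457            # epoch seconds with millisecond precision
old_ns = int(int(t) * 1e9)    # 1730170332000000000 -- the .457 s is dropped
new_ns = int(t * 1e9)         # ~1730170332457000000 -- sub-second precision kept
```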

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -260,6 +260,7 @@ async def user_api_key_auth( # noqa: PLR0915
headers=request.headers
),
)
### USER-DEFINED AUTH FUNCTION ###
if user_custom_auth is not None:
response = await user_custom_auth(request=request, api_key=api_key) # type: ignore

View file

@ -28,7 +28,9 @@ class _PROXY_MaxBudgetLimiter(CustomLogger):
try:
self.print_verbose("Inside Max Budget Limiter Pre-Call Hook")
cache_key = f"{user_api_key_dict.user_id}_user_api_key_user_id"
user_row = cache.get_cache(cache_key)
user_row = await cache.async_get_cache(
cache_key, parent_otel_span=user_api_key_dict.parent_otel_span
)
if user_row is None: # value not yet cached
return
max_budget = user_row["max_budget"]

View file

@ -235,7 +235,7 @@ class InternalUsageCache:
return await self.dual_cache.async_get_cache(
key=key,
local_only=local_only,
litellm_parent_otel_span=litellm_parent_otel_span,
parent_otel_span=litellm_parent_otel_span,
**kwargs,
)
@ -281,7 +281,7 @@ class InternalUsageCache:
key=key,
value=value,
local_only=local_only,
litellm_parent_otel_span=litellm_parent_otel_span,
parent_otel_span=litellm_parent_otel_span,
**kwargs,
)
@ -367,7 +367,10 @@ class ProxyLogging:
llm_router=llm_router
) # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made
if "daily_reports" in self.slack_alerting_instance.alert_types:
if (
self.slack_alerting_instance is not None
and "daily_reports" in self.slack_alerting_instance.alert_types
):
asyncio.create_task(
self.slack_alerting_instance._run_scheduled_daily_report(
llm_router=llm_router

View file

@ -25,6 +25,7 @@ import uuid
from collections import defaultdict
from datetime import datetime
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
@ -50,6 +51,7 @@ from litellm._logging import verbose_router_logger
from litellm.assistants.main import AssistantDeleted
from litellm.caching.caching import DualCache, InMemoryCache, RedisCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
from litellm.llms.AzureOpenAI.azure import get_azure_ad_token_from_oidc
from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
@ -124,6 +126,7 @@ from litellm.types.router import (
updateDeployment,
updateLiteLLMParams,
)
from litellm.types.services import ServiceLoggerPayload, ServiceTypes
from litellm.types.utils import OPENAI_RESPONSE_HEADERS
from litellm.types.utils import ModelInfo as ModelMapInfo
from litellm.utils import (
@ -140,6 +143,13 @@ from litellm.utils import (
from .router_utils.pattern_match_deployments import PatternMatchRouter
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
Span = _Span
else:
Span = Any
class RoutingArgs(enum.Enum):
ttl = 60 # 1min (RPM/TPM expire key)
@ -293,6 +303,8 @@ class Router:
```
"""
from litellm._service_logger import ServiceLogging
if semaphore:
self.semaphore = semaphore
self.set_verbose = set_verbose
@ -494,7 +506,7 @@ class Router:
f"Routing context window fallbacks: {self.context_window_fallbacks}\n\n"
f"Router Redis Caching={self.cache.redis_cache}\n"
)
self.service_logger_obj = ServiceLogging()
self.routing_strategy_args = routing_strategy_args
self.retry_policy: Optional[RetryPolicy] = None
if retry_policy is not None:
@ -762,10 +774,23 @@ class Router:
request_priority = kwargs.get("priority") or self.default_priority
start_time = time.time()
if request_priority is not None and isinstance(request_priority, int):
response = await self.schedule_acompletion(**kwargs)
else:
response = await self.async_function_with_fallbacks(**kwargs)
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_success_hook(
service=ServiceTypes.ROUTER,
duration=_duration,
call_type="acompletion",
start_time=start_time,
end_time=end_time,
parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
)
)
return response
except Exception as e:
@ -793,15 +818,32 @@ class Router:
verbose_router_logger.debug(
f"Inside _acompletion()- model: {model}; kwargs: {kwargs}"
)
parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
start_time = time.time()
deployment = await self.async_get_available_deployment(
model=model,
messages=messages,
specific_deployment=kwargs.pop("specific_deployment", None),
request_kwargs=kwargs,
)
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_success_hook(
service=ServiceTypes.ROUTER,
duration=_duration,
call_type="async_get_available_deployment",
start_time=start_time,
end_time=end_time,
parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
)
)
# debug how often this deployment picked
self._track_deployment_metrics(deployment=deployment)
self._track_deployment_metrics(
deployment=deployment, parent_otel_span=parent_otel_span
)
self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs)
data = deployment["litellm_params"].copy()
@ -846,12 +888,16 @@ class Router:
- If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
"""
await self.async_routing_strategy_pre_call_checks(
deployment=deployment, logging_obj=logging_obj
deployment=deployment,
logging_obj=logging_obj,
parent_otel_span=parent_otel_span,
)
response = await _response
else:
await self.async_routing_strategy_pre_call_checks(
deployment=deployment, logging_obj=logging_obj
deployment=deployment,
logging_obj=logging_obj,
parent_otel_span=parent_otel_span,
)
response = await _response
@ -872,7 +918,11 @@ class Router:
f"litellm.acompletion(model={model_name})\033[32m 200 OK\033[0m"
)
# debug how often this deployment picked
self._track_deployment_metrics(deployment=deployment, response=response)
self._track_deployment_metrics(
deployment=deployment,
response=response,
parent_otel_span=parent_otel_span,
)
return response
except Exception as e:
@ -1212,6 +1262,7 @@ class Router:
stream=False,
**kwargs,
):
parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
### FLOW ITEM ###
_request_id = str(uuid.uuid4())
item = FlowItem(
@ -1232,7 +1283,7 @@ class Router:
while curr_time < end_time:
_healthy_deployments, _ = await self._async_get_healthy_deployments(
model=model
model=model, parent_otel_span=parent_otel_span
)
make_request = await self.scheduler.poll( ## POLL QUEUE ## - returns 'True' if there's healthy deployments OR if request is at top of queue
id=item.request_id,
@ -1353,6 +1404,7 @@ class Router:
verbose_router_logger.debug(
f"Inside _image_generation()- model: {model}; kwargs: {kwargs}"
)
parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
deployment = await self.async_get_available_deployment(
model=model,
messages=[{"role": "user", "content": "prompt"}],
@ -1395,11 +1447,13 @@ class Router:
- If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
"""
await self.async_routing_strategy_pre_call_checks(
deployment=deployment
deployment=deployment, parent_otel_span=parent_otel_span
)
response = await response
else:
await self.async_routing_strategy_pre_call_checks(deployment=deployment)
await self.async_routing_strategy_pre_call_checks(
deployment=deployment, parent_otel_span=parent_otel_span
)
response = await response
self.success_calls[model_name] += 1
@ -1465,6 +1519,7 @@ class Router:
verbose_router_logger.debug(
f"Inside _atranscription()- model: {model}; kwargs: {kwargs}"
)
parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
deployment = await self.async_get_available_deployment(
model=model,
messages=[{"role": "user", "content": "prompt"}],
@ -1505,11 +1560,13 @@ class Router:
- If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
"""
await self.async_routing_strategy_pre_call_checks(
deployment=deployment
deployment=deployment, parent_otel_span=parent_otel_span
)
response = await response
else:
await self.async_routing_strategy_pre_call_checks(deployment=deployment)
await self.async_routing_strategy_pre_call_checks(
deployment=deployment, parent_otel_span=parent_otel_span
)
response = await response
self.success_calls[model_name] += 1
@ -1861,6 +1918,7 @@ class Router:
verbose_router_logger.debug(
f"Inside _atext_completion()- model: {model}; kwargs: {kwargs}"
)
parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
deployment = await self.async_get_available_deployment(
model=model,
messages=[{"role": "user", "content": prompt}],
@ -1903,11 +1961,13 @@ class Router:
- If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
"""
await self.async_routing_strategy_pre_call_checks(
deployment=deployment
deployment=deployment, parent_otel_span=parent_otel_span
)
response = await response
else:
await self.async_routing_strategy_pre_call_checks(deployment=deployment)
await self.async_routing_strategy_pre_call_checks(
deployment=deployment, parent_otel_span=parent_otel_span
)
response = await response
self.success_calls[model_name] += 1
@ -1958,6 +2018,7 @@ class Router:
verbose_router_logger.debug(
f"Inside _aadapter_completion()- model: {model}; kwargs: {kwargs}"
)
parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
deployment = await self.async_get_available_deployment(
model=model,
messages=[{"role": "user", "content": "default text"}],
@ -2000,11 +2061,13 @@ class Router:
- If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
"""
await self.async_routing_strategy_pre_call_checks(
deployment=deployment
deployment=deployment, parent_otel_span=parent_otel_span
)
response = await response # type: ignore
else:
await self.async_routing_strategy_pre_call_checks(deployment=deployment)
await self.async_routing_strategy_pre_call_checks(
deployment=deployment, parent_otel_span=parent_otel_span
)
response = await response # type: ignore
self.success_calls[model_name] += 1
@ -2128,6 +2191,7 @@ class Router:
verbose_router_logger.debug(
f"Inside _aembedding()- model: {model}; kwargs: {kwargs}"
)
parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
deployment = await self.async_get_available_deployment(
model=model,
input=input,
@ -2168,11 +2232,13 @@ class Router:
- If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
"""
await self.async_routing_strategy_pre_call_checks(
deployment=deployment
deployment=deployment, parent_otel_span=parent_otel_span
)
response = await response
else:
await self.async_routing_strategy_pre_call_checks(deployment=deployment)
await self.async_routing_strategy_pre_call_checks(
deployment=deployment, parent_otel_span=parent_otel_span
)
response = await response
self.success_calls[model_name] += 1
@ -2223,6 +2289,7 @@ class Router:
verbose_router_logger.debug(
f"Inside _atext_completion()- model: {model}; kwargs: {kwargs}"
)
parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
deployment = await self.async_get_available_deployment(
model=model,
messages=[{"role": "user", "content": "files-api-fake-text"}],
@ -2273,11 +2340,13 @@ class Router:
- If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
"""
await self.async_routing_strategy_pre_call_checks(
deployment=deployment
deployment=deployment, parent_otel_span=parent_otel_span
)
response = await response # type: ignore
else:
await self.async_routing_strategy_pre_call_checks(deployment=deployment)
await self.async_routing_strategy_pre_call_checks(
deployment=deployment, parent_otel_span=parent_otel_span
)
response = await response # type: ignore
self.success_calls[model_name] += 1
@ -2327,6 +2396,7 @@ class Router:
verbose_router_logger.debug(
f"Inside _acreate_batch()- model: {model}; kwargs: {kwargs}"
)
parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
deployment = await self.async_get_available_deployment(
model=model,
messages=[{"role": "user", "content": "files-api-fake-text"}],
@ -2389,11 +2459,13 @@ class Router:
- If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
"""
await self.async_routing_strategy_pre_call_checks(
deployment=deployment
deployment=deployment, parent_otel_span=parent_otel_span
)
response = await response # type: ignore
else:
await self.async_routing_strategy_pre_call_checks(deployment=deployment)
await self.async_routing_strategy_pre_call_checks(
deployment=deployment, parent_otel_span=parent_otel_span
)
response = await response # type: ignore
self.success_calls[model_name] += 1
@ -2702,12 +2774,14 @@ class Router:
)
return response
except Exception as new_exception:
parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
verbose_router_logger.error(
"litellm.router.py::async_function_with_fallbacks() - Error occurred while trying to do fallbacks - {}\n{}\n\nDebug Information:\nCooldown Deployments={}".format(
str(new_exception),
traceback.format_exc(),
await _async_get_cooldown_deployments_with_debug_info(
litellm_router_instance=self
litellm_router_instance=self,
parent_otel_span=parent_otel_span,
),
)
)
@ -2779,12 +2853,13 @@ class Router:
Context_Policy_Fallbacks={content_policy_fallbacks}",
)
async def async_function_with_retries(self, *args, **kwargs):
async def async_function_with_retries(self, *args, **kwargs): # noqa: PLR0915
verbose_router_logger.debug(
f"Inside async function with retries: args - {args}; kwargs - {kwargs}"
)
original_function = kwargs.pop("original_function")
fallbacks = kwargs.pop("fallbacks", self.fallbacks)
parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
context_window_fallbacks = kwargs.pop(
"context_window_fallbacks", self.context_window_fallbacks
)
@ -2822,6 +2897,7 @@ class Router:
_healthy_deployments, _all_deployments = (
await self._async_get_healthy_deployments(
model=kwargs.get("model") or "",
parent_otel_span=parent_otel_span,
)
)
@ -2879,6 +2955,7 @@ class Router:
_healthy_deployments, _ = (
await self._async_get_healthy_deployments(
model=_model,
parent_otel_span=parent_otel_span,
)
)
else:
@ -3217,8 +3294,10 @@ class Router:
if _model is None:
raise e # re-raise error, if model can't be determined for loadbalancing
### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR
parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
_healthy_deployments, _all_deployments = self._get_healthy_deployments(
model=_model,
parent_otel_span=parent_otel_span,
)
# raises an exception if this error should not be retries
@ -3260,8 +3339,10 @@ class Router:
if _model is None:
raise e # re-raise error, if model can't be determined for loadbalancing
parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
_healthy_deployments, _ = self._get_healthy_deployments(
model=_model,
parent_otel_span=parent_otel_span,
)
remaining_retries = num_retries - current_attempt
_timeout = self._time_to_sleep_before_retry(
@ -3323,9 +3404,13 @@ class Router:
# ------------
# update cache
parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
## TPM
await self.cache.async_increment_cache(
key=tpm_key, value=total_tokens, ttl=RoutingArgs.ttl.value
key=tpm_key,
value=total_tokens,
parent_otel_span=parent_otel_span,
ttl=RoutingArgs.ttl.value,
)
increment_deployment_successes_for_current_minute(
@ -3474,7 +3559,9 @@ class Router:
except Exception as e:
raise e
def _update_usage(self, deployment_id: str) -> int:
def _update_usage(
self, deployment_id: str, parent_otel_span: Optional[Span]
) -> int:
"""
Update deployment rpm for that minute
@ -3483,7 +3570,9 @@ class Router:
"""
rpm_key = deployment_id
request_count = self.cache.get_cache(key=rpm_key, local_only=True)
request_count = self.cache.get_cache(
key=rpm_key, parent_otel_span=parent_otel_span, local_only=True
)
if request_count is None:
request_count = 1
self.cache.set_cache(
@ -3591,7 +3680,7 @@ class Router:
)
return False
def _get_healthy_deployments(self, model: str):
def _get_healthy_deployments(self, model: str, parent_otel_span: Optional[Span]):
_all_deployments: list = []
try:
_, _all_deployments = self._common_checks_available_deployment( # type: ignore
@ -3602,7 +3691,9 @@ class Router:
except Exception:
pass
unhealthy_deployments = _get_cooldown_deployments(litellm_router_instance=self)
unhealthy_deployments = _get_cooldown_deployments(
litellm_router_instance=self, parent_otel_span=parent_otel_span
)
healthy_deployments: list = []
for deployment in _all_deployments:
if deployment["model_info"]["id"] in unhealthy_deployments:
@ -3613,7 +3704,7 @@ class Router:
return healthy_deployments, _all_deployments
async def _async_get_healthy_deployments(
self, model: str
self, model: str, parent_otel_span: Optional[Span]
) -> Tuple[List[Dict], List[Dict]]:
"""
Returns Tuple of:
@ -3632,7 +3723,7 @@ class Router:
pass
unhealthy_deployments = await _async_get_cooldown_deployments(
litellm_router_instance=self
litellm_router_instance=self, parent_otel_span=parent_otel_span
)
healthy_deployments: list = []
for deployment in _all_deployments:
@ -3659,7 +3750,10 @@ class Router:
_callback.pre_call_check(deployment)
async def async_routing_strategy_pre_call_checks(
self, deployment: dict, logging_obj: Optional[LiteLLMLogging] = None
self,
deployment: dict,
parent_otel_span: Optional[Span],
logging_obj: Optional[LiteLLMLogging] = None,
):
"""
For usage-based-routing-v2, enables running rpm checks before the call is made, inside the semaphore.
@ -3675,7 +3769,7 @@ class Router:
for _callback in litellm.callbacks:
if isinstance(_callback, CustomLogger):
try:
await _callback.async_pre_call_check(deployment)
await _callback.async_pre_call_check(deployment, parent_otel_span)
except litellm.RateLimitError as e:
## LOG FAILURE EVENT
if logging_obj is not None:
@ -4646,14 +4740,19 @@ class Router:
The appropriate client based on the given client_type and kwargs.
"""
model_id = deployment["model_info"]["id"]
parent_otel_span: Optional[Span] = _get_parent_otel_span_from_kwargs(kwargs)
if client_type == "max_parallel_requests":
cache_key = "{}_max_parallel_requests_client".format(model_id)
client = self.cache.get_cache(key=cache_key, local_only=True)
client = self.cache.get_cache(
key=cache_key, local_only=True, parent_otel_span=parent_otel_span
)
return client
elif client_type == "async":
if kwargs.get("stream") is True:
cache_key = f"{model_id}_stream_async_client"
client = self.cache.get_cache(key=cache_key, local_only=True)
client = self.cache.get_cache(
key=cache_key, local_only=True, parent_otel_span=parent_otel_span
)
if client is None:
"""
Re-initialize the client
@ -4661,11 +4760,17 @@ class Router:
InitalizeOpenAISDKClient.set_client(
litellm_router_instance=self, model=deployment
)
client = self.cache.get_cache(key=cache_key, local_only=True)
client = self.cache.get_cache(
key=cache_key,
local_only=True,
parent_otel_span=parent_otel_span,
)
return client
else:
cache_key = f"{model_id}_async_client"
client = self.cache.get_cache(key=cache_key, local_only=True)
client = self.cache.get_cache(
key=cache_key, local_only=True, parent_otel_span=parent_otel_span
)
if client is None:
"""
Re-initialize the client
@ -4673,12 +4778,18 @@ class Router:
InitalizeOpenAISDKClient.set_client(
litellm_router_instance=self, model=deployment
)
client = self.cache.get_cache(key=cache_key, local_only=True)
client = self.cache.get_cache(
key=cache_key,
local_only=True,
parent_otel_span=parent_otel_span,
)
return client
else:
if kwargs.get("stream") is True:
cache_key = f"{model_id}_stream_client"
client = self.cache.get_cache(key=cache_key)
client = self.cache.get_cache(
key=cache_key, parent_otel_span=parent_otel_span
)
if client is None:
"""
Re-initialize the client
@ -4686,11 +4797,15 @@ class Router:
InitalizeOpenAISDKClient.set_client(
litellm_router_instance=self, model=deployment
)
client = self.cache.get_cache(key=cache_key)
client = self.cache.get_cache(
key=cache_key, parent_otel_span=parent_otel_span
)
return client
else:
cache_key = f"{model_id}_client"
client = self.cache.get_cache(key=cache_key)
client = self.cache.get_cache(
key=cache_key, parent_otel_span=parent_otel_span
)
if client is None:
"""
Re-initialize the client
@ -4698,7 +4813,9 @@ class Router:
InitalizeOpenAISDKClient.set_client(
litellm_router_instance=self, model=deployment
)
client = self.cache.get_cache(key=cache_key)
client = self.cache.get_cache(
key=cache_key, parent_otel_span=parent_otel_span
)
return client
def _pre_call_checks( # noqa: PLR0915
@ -4738,13 +4855,17 @@ class Router:
_context_window_error = False
_potential_error_str = ""
_rate_limit_error = False
parent_otel_span = _get_parent_otel_span_from_kwargs(request_kwargs)
## get model group RPM ##
dt = get_utc_datetime()
current_minute = dt.strftime("%H-%M")
rpm_key = f"{model}:rpm:{current_minute}"
model_group_cache = (
self.cache.get_cache(key=rpm_key, local_only=True) or {}
self.cache.get_cache(
key=rpm_key, local_only=True, parent_otel_span=parent_otel_span
)
or {}
) # check the in-memory cache used by lowest_latency and usage-based routing. Only check the local cache.
for idx, deployment in enumerate(_returned_deployments):
# see if we have the info for this model
@ -4783,7 +4904,10 @@ class Router:
## RPM CHECK ##
### get local router cache ###
current_request_cache_local = (
self.cache.get_cache(key=model_id, local_only=True) or 0
self.cache.get_cache(
key=model_id, local_only=True, parent_otel_span=parent_otel_span
)
or 0
)
### get usage based cache ###
if (
@ -5002,6 +5126,7 @@ class Router:
self.routing_strategy != "usage-based-routing-v2"
and self.routing_strategy != "simple-shuffle"
and self.routing_strategy != "cost-based-routing"
and self.routing_strategy != "latency-based-routing"
): # prevent regressions for other routing strategies, that don't have async get available deployments implemented.
return self.get_available_deployment(
model=model,
@ -5011,6 +5136,7 @@ class Router:
request_kwargs=request_kwargs,
)
try:
parent_otel_span = _get_parent_otel_span_from_kwargs(request_kwargs)
model, healthy_deployments = self._common_checks_available_deployment(
model=model,
messages=messages,
@ -5021,7 +5147,7 @@ class Router:
return healthy_deployments
cooldown_deployments = await _async_get_cooldown_deployments(
litellm_router_instance=self
litellm_router_instance=self, parent_otel_span=parent_otel_span
)
verbose_router_logger.debug(
f"async cooldown deployments: {cooldown_deployments}"
@ -5059,16 +5185,18 @@ class Router:
_allowed_model_region = "n/a"
model_ids = self.get_model_ids(model_name=model)
_cooldown_time = self.cooldown_cache.get_min_cooldown(
model_ids=model_ids
model_ids=model_ids, parent_otel_span=parent_otel_span
)
_cooldown_list = _get_cooldown_deployments(
litellm_router_instance=self, parent_otel_span=parent_otel_span
)
_cooldown_list = _get_cooldown_deployments(litellm_router_instance=self)
raise RouterRateLimitError(
model=model,
cooldown_time=_cooldown_time,
enable_pre_call_checks=self.enable_pre_call_checks,
cooldown_list=_cooldown_list,
)
start_time = time.time()
if (
self.routing_strategy == "usage-based-routing-v2"
and self.lowesttpm_logger_v2 is not None
@ -5093,6 +5221,19 @@ class Router:
input=input,
)
)
elif (
self.routing_strategy == "latency-based-routing"
and self.lowestlatency_logger is not None
):
deployment = (
await self.lowestlatency_logger.async_get_available_deployments(
model_group=model,
healthy_deployments=healthy_deployments, # type: ignore
messages=messages,
input=input,
request_kwargs=request_kwargs,
)
)
elif self.routing_strategy == "simple-shuffle":
return simple_shuffle(
llm_router_instance=self,
@ -5107,9 +5248,11 @@ class Router:
)
model_ids = self.get_model_ids(model_name=model)
_cooldown_time = self.cooldown_cache.get_min_cooldown(
model_ids=model_ids
model_ids=model_ids, parent_otel_span=parent_otel_span
)
_cooldown_list = _get_cooldown_deployments(
litellm_router_instance=self, parent_otel_span=parent_otel_span
)
_cooldown_list = _get_cooldown_deployments(litellm_router_instance=self)
raise RouterRateLimitError(
model=model,
cooldown_time=_cooldown_time,
@ -5120,6 +5263,19 @@ class Router:
f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment)} for model: {model}"
)
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_success_hook(
service=ServiceTypes.ROUTER,
duration=_duration,
call_type="<routing_strategy>.async_get_available_deployments",
parent_otel_span=parent_otel_span,
start_time=start_time,
end_time=end_time,
)
)
return deployment
except Exception as e:
traceback_exception = traceback.format_exc()
@ -5163,7 +5319,12 @@ class Router:
if isinstance(healthy_deployments, dict):
return healthy_deployments
cooldown_deployments = _get_cooldown_deployments(litellm_router_instance=self)
parent_otel_span: Optional[Span] = _get_parent_otel_span_from_kwargs(
request_kwargs
)
cooldown_deployments = _get_cooldown_deployments(
litellm_router_instance=self, parent_otel_span=parent_otel_span
)
healthy_deployments = self._filter_cooldown_deployments(
healthy_deployments=healthy_deployments,
cooldown_deployments=cooldown_deployments,
@ -5180,8 +5341,12 @@ class Router:
if len(healthy_deployments) == 0:
model_ids = self.get_model_ids(model_name=model)
_cooldown_time = self.cooldown_cache.get_min_cooldown(model_ids=model_ids)
_cooldown_list = _get_cooldown_deployments(litellm_router_instance=self)
_cooldown_time = self.cooldown_cache.get_min_cooldown(
model_ids=model_ids, parent_otel_span=parent_otel_span
)
_cooldown_list = _get_cooldown_deployments(
litellm_router_instance=self, parent_otel_span=parent_otel_span
)
raise RouterRateLimitError(
model=model,
cooldown_time=_cooldown_time,
@ -5238,8 +5403,12 @@ class Router:
f"get_available_deployment for model: {model}, No deployment available"
)
model_ids = self.get_model_ids(model_name=model)
_cooldown_time = self.cooldown_cache.get_min_cooldown(model_ids=model_ids)
_cooldown_list = _get_cooldown_deployments(litellm_router_instance=self)
_cooldown_time = self.cooldown_cache.get_min_cooldown(
model_ids=model_ids, parent_otel_span=parent_otel_span
)
_cooldown_list = _get_cooldown_deployments(
litellm_router_instance=self, parent_otel_span=parent_otel_span
)
raise RouterRateLimitError(
model=model,
cooldown_time=_cooldown_time,
@ -5278,7 +5447,9 @@ class Router:
healthy_deployments.remove(deployment)
return healthy_deployments
def _track_deployment_metrics(self, deployment, response=None):
def _track_deployment_metrics(
self, deployment, parent_otel_span: Optional[Span], response=None
):
"""
Tracks successful requests rpm usage.
"""
@ -5288,7 +5459,9 @@ class Router:
# update self.deployment_stats
if model_id is not None:
self._update_usage(model_id) # update in-memory cache for tracking
self._update_usage(
model_id, parent_otel_span
) # update in-memory cache for tracking
except Exception as e:
verbose_router_logger.error(f"Error in _track_deployment_metrics: {str(e)}")
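
The router changes above repeat one pattern: time a routing step (`acompletion`, `async_get_available_deployment`, the strategy's deployment pick), then report it as a `ServiceTypes.ROUTER` event against the request's parent span via a fire-and-forget task, so instrumentation never blocks the request. A condensed sketch of that pattern; the wrapper itself is illustrative, not part of the Router API:

```python
import asyncio
import time

from litellm.types.services import ServiceTypes


async def timed_router_step(router, call_type: str, parent_otel_span, coro):
    # illustrative wrapper mirroring the inline timing added in this diff
    start_time = time.time()
    result = await coro
    end_time = time.time()
    asyncio.create_task(
        router.service_logger_obj.async_service_success_hook(
            service=ServiceTypes.ROUTER,
            duration=end_time - start_time,
            call_type=call_type,
            start_time=start_time,
            end_time=end_time,
            parent_otel_span=parent_otel_span,
        )
    )
    return result
```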

View file

@ -3,7 +3,7 @@
import random
import traceback
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from pydantic import BaseModel
@ -11,6 +11,14 @@ import litellm
from litellm import ModelResponse, token_counter, verbose_logger
from litellm.caching.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
Span = _Span
else:
Span = Any
class LiteLLMBase(BaseModel):
@ -115,8 +123,13 @@ class LowestLatencyLoggingHandler(CustomLogger):
# ------------
# Update usage
# ------------
request_count_dict = self.router_cache.get_cache(key=latency_key) or {}
parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
request_count_dict = (
self.router_cache.get_cache(
key=latency_key, parent_otel_span=parent_otel_span
)
or {}
)
if id not in request_count_dict:
request_count_dict[id] = {}
@ -213,7 +226,7 @@ class LowestLatencyLoggingHandler(CustomLogger):
"""
latency_key = f"{model_group}_map"
request_count_dict = (
self.router_cache.get_cache(key=latency_key) or {}
await self.router_cache.async_get_cache(key=latency_key) or {}
)
if id not in request_count_dict:
@ -316,8 +329,15 @@ class LowestLatencyLoggingHandler(CustomLogger):
# ------------
# Update usage
# ------------
request_count_dict = self.router_cache.get_cache(key=latency_key) or {}
parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
request_count_dict = (
await self.router_cache.async_get_cache(
key=latency_key,
parent_otel_span=parent_otel_span,
local_only=True,
)
or {}
)
if id not in request_count_dict:
request_count_dict[id] = {}
@ -379,26 +399,21 @@ class LowestLatencyLoggingHandler(CustomLogger):
)
pass
def get_available_deployments( # noqa: PLR0915
def _get_available_deployments( # noqa: PLR0915
self,
model_group: str,
healthy_deployments: list,
messages: Optional[List[Dict[str, str]]] = None,
input: Optional[Union[str, List]] = None,
request_kwargs: Optional[Dict] = None,
request_count_dict: Optional[Dict] = None,
):
"""
Returns a deployment with the lowest latency
"""
# get list of potential deployments
latency_key = f"{model_group}_map"
_latency_per_deployment = {}
request_count_dict = self.router_cache.get_cache(key=latency_key) or {}
"""Common logic for both sync and async get_available_deployments"""
# -----------------------
# Find lowest used model
# ----------------------
_latency_per_deployment = {}
lowest_latency = float("inf")
current_date = datetime.now().strftime("%Y-%m-%d")
@ -428,8 +443,8 @@ class LowestLatencyLoggingHandler(CustomLogger):
# randomly sample from all_deployments, incase all deployments have latency=0.0
_items = all_deployments.items()
all_deployments = random.sample(list(_items), len(_items))
all_deployments = dict(all_deployments)
_all_deployments = random.sample(list(_items), len(_items))
all_deployments = dict(_all_deployments)
### GET AVAILABLE DEPLOYMENTS ### filter out any deployments > tpm/rpm limits
potential_deployments = []
@ -525,3 +540,66 @@ class LowestLatencyLoggingHandler(CustomLogger):
"_latency_per_deployment"
] = _latency_per_deployment
return deployment
async def async_get_available_deployments(
self,
model_group: str,
healthy_deployments: list,
messages: Optional[List[Dict[str, str]]] = None,
input: Optional[Union[str, List]] = None,
request_kwargs: Optional[Dict] = None,
):
# get list of potential deployments
latency_key = f"{model_group}_map"
parent_otel_span: Optional[Span] = _get_parent_otel_span_from_kwargs(
request_kwargs
)
request_count_dict = (
await self.router_cache.async_get_cache(
key=latency_key, parent_otel_span=parent_otel_span
)
or {}
)
return self._get_available_deployments(
model_group,
healthy_deployments,
messages,
input,
request_kwargs,
request_count_dict,
)
def get_available_deployments(
self,
model_group: str,
healthy_deployments: list,
messages: Optional[List[Dict[str, str]]] = None,
input: Optional[Union[str, List]] = None,
request_kwargs: Optional[Dict] = None,
):
"""
Returns a deployment with the lowest latency
"""
# get list of potential deployments
latency_key = f"{model_group}_map"
parent_otel_span: Optional[Span] = _get_parent_otel_span_from_kwargs(
request_kwargs
)
request_count_dict = (
self.router_cache.get_cache(
key=latency_key, parent_otel_span=parent_otel_span
)
or {}
)
return self._get_available_deployments(
model_group,
healthy_deployments,
messages,
input,
request_kwargs,
request_count_dict,
)
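
With `async_get_available_deployments` added above and the router's strategy check relaxed to include `latency-based-routing`, latency routing now resolves deployments through the async cache path. A hypothetical two-deployment setup that would exercise it:

```python
from litellm import Router

# Illustrative config; deployment names are placeholders.
router = Router(
    model_list=[
        {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "gpt-3.5-turbo"}},
        {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "gpt-3.5-turbo-0125"}},
    ],
    routing_strategy="latency-based-routing",
)
```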

View file

@ -2,7 +2,7 @@
# identifies lowest tpm deployment
import random
import traceback
from typing import Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
import httpx
from pydantic import BaseModel
@ -12,9 +12,17 @@ from litellm import token_counter
from litellm._logging import verbose_logger, verbose_router_logger
from litellm.caching.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
from litellm.types.router import RouterErrors
from litellm.utils import get_utc_datetime, print_verbose
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
Span = _Span
else:
Span = Any
class LiteLLMBase(BaseModel):
"""
@ -136,7 +144,9 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
raise e
return deployment # don't fail calls if eg. redis fails to connect
async def async_pre_call_check(self, deployment: Dict) -> Optional[Dict]:
async def async_pre_call_check(
self, deployment: Dict, parent_otel_span: Optional[Span]
) -> Optional[Dict]:
"""
Pre-call check + update model rpm
- Used inside semaphore
@ -192,7 +202,10 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
else:
# if local result below limit, check redis ## prevent unnecessary redis checks
result = await self.router_cache.async_increment_cache(
key=rpm_key, value=1, ttl=self.routing_args.ttl
key=rpm_key,
value=1,
ttl=self.routing_args.ttl,
parent_otel_span=parent_otel_span,
)
if result is not None and result > deployment_rpm:
raise litellm.RateLimitError(
@ -301,10 +314,13 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
# Update usage
# ------------
# update cache
parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
## TPM
await self.router_cache.async_increment_cache(
key=tpm_key, value=total_tokens, ttl=self.routing_args.ttl
key=tpm_key,
value=total_tokens,
ttl=self.routing_args.ttl,
parent_otel_span=parent_otel_span,
)
### TESTING ###
@ -547,6 +563,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
healthy_deployments: list,
messages: Optional[List[Dict[str, str]]] = None,
input: Optional[Union[str, List]] = None,
parent_otel_span: Optional[Span] = None,
):
"""
Returns a deployment with the lowest TPM/RPM usage.
@ -572,10 +589,10 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
rpm_keys.append(rpm_key)
tpm_values = self.router_cache.batch_get_cache(
keys=tpm_keys
keys=tpm_keys, parent_otel_span=parent_otel_span
) # [1, 2, None, ..]
rpm_values = self.router_cache.batch_get_cache(
keys=rpm_keys
keys=rpm_keys, parent_otel_span=parent_otel_span
) # [1, 2, None, ..]
deployment = self._common_checks_available_deployment(

View file

@ -4,11 +4,18 @@ Wrapper around router cache. Meant to handle model cooldown logic
import json
import time
from typing import List, Optional, Tuple, TypedDict
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, TypedDict
from litellm import verbose_logger
from litellm.caching.caching import DualCache
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
Span = _Span
else:
Span = Any
class CooldownCacheValue(TypedDict):
exception_received: str
@ -77,13 +84,18 @@ class CooldownCache:
raise e
async def async_get_active_cooldowns(
self, model_ids: List[str]
self, model_ids: List[str], parent_otel_span: Optional[Span]
) -> List[Tuple[str, CooldownCacheValue]]:
# Generate the keys for the deployments
keys = [f"deployment:{model_id}:cooldown" for model_id in model_ids]
# Retrieve the values for the keys using mget
results = await self.cache.async_batch_get_cache(keys=keys) or []
results = (
await self.cache.async_batch_get_cache(
keys=keys, parent_otel_span=parent_otel_span
)
or []
)
active_cooldowns = []
# Process the results
@ -95,13 +107,15 @@ class CooldownCache:
return active_cooldowns
def get_active_cooldowns(
self, model_ids: List[str]
self, model_ids: List[str], parent_otel_span: Optional[Span]
) -> List[Tuple[str, CooldownCacheValue]]:
# Generate the keys for the deployments
keys = [f"deployment:{model_id}:cooldown" for model_id in model_ids]
# Retrieve the values for the keys using mget
results = self.cache.batch_get_cache(keys=keys) or []
results = (
self.cache.batch_get_cache(keys=keys, parent_otel_span=parent_otel_span)
or []
)
active_cooldowns = []
# Process the results
@ -112,14 +126,19 @@ class CooldownCache:
return active_cooldowns
def get_min_cooldown(self, model_ids: List[str]) -> float:
def get_min_cooldown(
self, model_ids: List[str], parent_otel_span: Optional[Span]
) -> float:
"""Return min cooldown time required for a group of model id's."""
# Generate the keys for the deployments
keys = [f"deployment:{model_id}:cooldown" for model_id in model_ids]
# Retrieve the values for the keys using mget
results = self.cache.batch_get_cache(keys=keys) or []
results = (
self.cache.batch_get_cache(keys=keys, parent_otel_span=parent_otel_span)
or []
)
min_cooldown_time: Optional[float] = None
# Process the results

View file

@ -20,12 +20,15 @@ from .router_callbacks.track_deployment_metrics import (
)
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
from litellm.router import Router as _Router
LitellmRouter = _Router
Span = _Span
else:
LitellmRouter = Any
Span = Any
DEFAULT_FAILURE_THRESHOLD_PERCENT = (
0.5 # default cooldown a deployment if 50% of requests fail in a given minute
)
@ -207,6 +210,7 @@ def _set_cooldown_deployments(
async def _async_get_cooldown_deployments(
litellm_router_instance: LitellmRouter,
parent_otel_span: Optional[Span],
) -> List[str]:
"""
Async implementation of '_get_cooldown_deployments'
@ -214,7 +218,8 @@ async def _async_get_cooldown_deployments(
model_ids = litellm_router_instance.get_model_ids()
cooldown_models = (
await litellm_router_instance.cooldown_cache.async_get_active_cooldowns(
model_ids=model_ids
model_ids=model_ids,
parent_otel_span=parent_otel_span,
)
)
@ -233,6 +238,7 @@ async def _async_get_cooldown_deployments(
async def _async_get_cooldown_deployments_with_debug_info(
litellm_router_instance: LitellmRouter,
parent_otel_span: Optional[Span],
) -> List[tuple]:
"""
Async implementation of '_get_cooldown_deployments'
@ -240,7 +246,7 @@ async def _async_get_cooldown_deployments_with_debug_info(
model_ids = litellm_router_instance.get_model_ids()
cooldown_models = (
await litellm_router_instance.cooldown_cache.async_get_active_cooldowns(
model_ids=model_ids
model_ids=model_ids, parent_otel_span=parent_otel_span
)
)
@ -248,7 +254,9 @@ async def _async_get_cooldown_deployments_with_debug_info(
return cooldown_models
def _get_cooldown_deployments(litellm_router_instance: LitellmRouter) -> List[str]:
def _get_cooldown_deployments(
litellm_router_instance: LitellmRouter, parent_otel_span: Optional[Span]
) -> List[str]:
"""
Get the list of models being cooled down for this minute
"""
@ -258,8 +266,9 @@ def _get_cooldown_deployments(litellm_router_instance: LitellmRouter) -> List[st
# Return cooldown models
# ----------------------
model_ids = litellm_router_instance.get_model_ids()
cooldown_models = litellm_router_instance.cooldown_cache.get_active_cooldowns(
model_ids=model_ids
model_ids=model_ids, parent_otel_span=parent_otel_span
)
cached_value_deployment_ids = []

View file

@ -14,6 +14,7 @@ class ServiceTypes(str, enum.Enum):
DB = "postgres"
BATCH_WRITE_TO_DB = "batch_write_to_db"
LITELLM = "self"
ROUTER = "router"
class ServiceLoggerPayload(BaseModel):

View file

@ -83,7 +83,9 @@ def test_dual_cache_batch_get_cache():
in_memory_cache.set_cache(key="test_value", value="hello world")
result = dual_cache.batch_get_cache(keys=["test_value", "test_value_2"])
result = dual_cache.batch_get_cache(
keys=["test_value", "test_value_2"], parent_otel_span=None
)
assert result[0] == "hello world"
assert result[1] == None

View file

@ -2447,11 +2447,11 @@ async def test_aaarouter_dynamic_cooldown_message_retry_time(sync_mode):
if sync_mode:
cooldown_deployments = _get_cooldown_deployments(
litellm_router_instance=router
litellm_router_instance=router, parent_otel_span=None
)
else:
cooldown_deployments = await _async_get_cooldown_deployments(
litellm_router_instance=router
litellm_router_instance=router, parent_otel_span=None
)
print(
"Cooldown deployments - {}\n{}".format(

View file

@ -242,12 +242,12 @@ async def test_single_deployment_no_cooldowns_test_prod_mock_completion_calls():
pass
cooldown_list = await _async_get_cooldown_deployments(
litellm_router_instance=router
litellm_router_instance=router, parent_otel_span=None
)
assert len(cooldown_list) == 0
healthy_deployments, _ = await router._async_get_healthy_deployments(
model="gpt-3.5-turbo"
model="gpt-3.5-turbo", parent_otel_span=None
)
print("healthy_deployments: ", healthy_deployments)
@ -351,7 +351,7 @@ async def test_high_traffic_cooldowns_all_healthy_deployments():
print("model_stats: ", model_stats)
cooldown_list = await _async_get_cooldown_deployments(
litellm_router_instance=router
litellm_router_instance=router, parent_otel_span=None
)
assert len(cooldown_list) == 0
@ -449,7 +449,7 @@ async def test_high_traffic_cooldowns_one_bad_deployment():
print("model_stats: ", model_stats)
cooldown_list = await _async_get_cooldown_deployments(
litellm_router_instance=router
litellm_router_instance=router, parent_otel_span=None
)
assert len(cooldown_list) == 1
@ -550,7 +550,7 @@ async def test_high_traffic_cooldowns_one_rate_limited_deployment():
print("model_stats: ", model_stats)
cooldown_list = await _async_get_cooldown_deployments(
litellm_router_instance=router
litellm_router_instance=router, parent_otel_span=None
)
assert len(cooldown_list) == 1

View file

@ -440,12 +440,12 @@ def test_update_usage(model_list):
)
deployment_id = deployment["model_info"]["id"]
request_count = router._update_usage(
deployment_id=deployment_id,
deployment_id=deployment_id, parent_otel_span=None
)
assert request_count == 1
request_count = router._update_usage(
deployment_id=deployment_id,
deployment_id=deployment_id, parent_otel_span=None
)
assert request_count == 2
@ -482,7 +482,9 @@ def test_should_raise_content_policy_error(model_list, finish_reason, expected_e
def test_get_healthy_deployments(model_list):
"""Test if the '_get_healthy_deployments' function is working correctly"""
router = Router(model_list=model_list)
deployments = router._get_healthy_deployments(model="gpt-3.5-turbo")
deployments = router._get_healthy_deployments(
model="gpt-3.5-turbo", parent_otel_span=None
)
assert len(deployments) > 0
@ -756,6 +758,7 @@ def test_track_deployment_metrics(model_list):
model="gpt-3.5-turbo",
usage={"total_tokens": 100},
),
parent_otel_span=None,
)