redis otel tracing + async support for latency routing (#6452)

* docs(exception_mapping.md): add missing exception types

Fixes https://github.com/Aider-AI/aider/issues/2120#issuecomment-2438971183

* fix(main.py): register custom model pricing with specific key

Ensure custom model pricing is registered to the specific model+provider key combination

* test: make testing more robust for custom pricing

* fix(redis_cache.py): instrument otel logging for sync redis calls

ensures complete coverage for all redis cache calls

* refactor: pass parent_otel_span for redis caching calls in router

allows for more observability into which calls are causing latency issues (see the sketch after this changelog)

* test: update tests with new params

* refactor: ensure e2e otel tracing for router

* refactor(router.py): add more otel tracing across the router

catch all latency issues for router requests

* fix: fix linting error

* fix(router.py): fix linting error

* fix: fix test

* test: fix tests

* fix(dual_cache.py): pass ttl to redis cache

* fix: fix param
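
The recurring pattern across these changes is to resolve the request's parent OTel span once, then pass it explicitly into every cache and routing call so Redis and routing latency show up as child spans of the originating request. A minimal sketch of that flow follows; the helper body and metadata key are illustrative assumptions, and only the helper name _get_parent_otel_span_from_kwargs and the async_get_cache(..., parent_otel_span=...) keyword come from this change.

from typing import Any, Optional

# In the real code, Span is imported from opentelemetry only under TYPE_CHECKING,
# so opentelemetry stays an optional dependency; Any is the runtime fallback.
Span = Any


def _get_parent_otel_span_from_kwargs(kwargs: dict) -> Optional[Span]:
    # Illustrative body: the proxy attaches the request's span to the call
    # metadata, and helpers like this recover it once per request.
    metadata = kwargs.get("metadata") or {}
    return metadata.get("litellm_parent_otel_span")


async def lookup_deployment_usage(cache, rpm_key: str, **request_kwargs):
    # Resolve the span once, then pass it explicitly instead of smuggling it
    # through **kwargs all the way down into the Redis client.
    parent_otel_span = _get_parent_otel_span_from_kwargs(request_kwargs)
    return await cache.async_get_cache(rpm_key, parent_otel_span=parent_otel_span)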
Krish Dholakia 2024-10-28 21:52:12 -07:00 committed by GitHub
parent d9e7818e6b
commit 4f8a3fd4cf
25 changed files with 559 additions and 147 deletions

View file

@@ -44,8 +44,7 @@ class ServiceLogging(CustomLogger):
         """
         Handles both sync and async monitoring by checking for existing event loop.
         """
-        # if service == ServiceTypes.REDIS:
-        #     print(f"SYNC service: {service}, call_type: {call_type}")
         if self.mock_testing:
             self.mock_testing_sync_success_hook += 1
@@ -112,8 +111,7 @@ class ServiceLogging(CustomLogger):
         """
         - For counting if the redis, postgres call is successful
         """
-        # if service == ServiceTypes.REDIS:
-        #     print(f"service: {service}, call_type: {call_type}")
         if self.mock_testing:
             self.mock_testing_async_success_hook += 1

View file

@@ -8,7 +8,14 @@ Has 4 methods:
 - async_get_cache
 """
-from typing import Optional
+from typing import TYPE_CHECKING, Any, Optional
+
+if TYPE_CHECKING:
+    from opentelemetry.trace import Span as _Span
+
+    Span = _Span
+else:
+    Span = Any
 
 
 class BaseCache:

View file

@@ -1,10 +1,17 @@
 import json
-from typing import Optional
+from typing import TYPE_CHECKING, Any, Optional
 
 from litellm._logging import print_verbose
 
 from .base_cache import BaseCache
 
+if TYPE_CHECKING:
+    from opentelemetry.trace import Span as _Span
+
+    Span = _Span
+else:
+    Span = Any
+
 
 class DiskCache(BaseCache):
     def __init__(self, disk_cache_dir: Optional[str] = None):

View file

@@ -9,7 +9,7 @@ Has 4 primary methods:
 """
 
 import traceback
-from typing import List, Optional
+from typing import TYPE_CHECKING, Any, List, Optional
 
 import litellm
 from litellm._logging import print_verbose, verbose_logger
@@ -18,6 +18,13 @@ from .base_cache import BaseCache
 from .in_memory_cache import InMemoryCache
 from .redis_cache import RedisCache
 
+if TYPE_CHECKING:
+    from opentelemetry.trace import Span as _Span
+
+    Span = _Span
+else:
+    Span = Any
+
 
 class DualCache(BaseCache):
     """
@@ -90,7 +97,13 @@ class DualCache(BaseCache):
             verbose_logger.error(f"LiteLLM Cache: Excepton async add_cache: {str(e)}")
             raise e
 
-    def get_cache(self, key, local_only: bool = False, **kwargs):
+    def get_cache(
+        self,
+        key,
+        parent_otel_span: Optional[Span] = None,
+        local_only: bool = False,
+        **kwargs,
+    ):
         # Try to fetch from in-memory cache first
         try:
             result = None
@@ -102,7 +115,9 @@ class DualCache(BaseCache):
 
             if result is None and self.redis_cache is not None and local_only is False:
                 # If not found in in-memory cache, try fetching from Redis
-                redis_result = self.redis_cache.get_cache(key, **kwargs)
+                redis_result = self.redis_cache.get_cache(
+                    key, parent_otel_span=parent_otel_span
+                )
 
                 if redis_result is not None:
                     # Update in-memory cache with the value from Redis
@@ -115,7 +130,13 @@ class DualCache(BaseCache):
         except Exception:
             verbose_logger.error(traceback.format_exc())
 
-    def batch_get_cache(self, keys: list, local_only: bool = False, **kwargs):
+    def batch_get_cache(
+        self,
+        keys: list,
+        parent_otel_span: Optional[Span],
+        local_only: bool = False,
+        **kwargs,
+    ):
         try:
             result = [None for _ in range(len(keys))]
             if self.in_memory_cache is not None:
@@ -133,7 +154,9 @@ class DualCache(BaseCache):
                     key for key, value in zip(keys, result) if value is None
                 ]
                 # If not found in in-memory cache, try fetching from Redis
-                redis_result = self.redis_cache.batch_get_cache(sublist_keys, **kwargs)
+                redis_result = self.redis_cache.batch_get_cache(
+                    sublist_keys, parent_otel_span=parent_otel_span
+                )
 
                 if redis_result is not None:
                     # Update in-memory cache with the value from Redis
                     for key in redis_result:
@@ -147,7 +170,13 @@ class DualCache(BaseCache):
         except Exception:
             verbose_logger.error(traceback.format_exc())
 
-    async def async_get_cache(self, key, local_only: bool = False, **kwargs):
+    async def async_get_cache(
+        self,
+        key,
+        parent_otel_span: Optional[Span] = None,
+        local_only: bool = False,
+        **kwargs,
+    ):
         # Try to fetch from in-memory cache first
         try:
             print_verbose(
@@ -165,7 +194,9 @@ class DualCache(BaseCache):
 
             if result is None and self.redis_cache is not None and local_only is False:
                 # If not found in in-memory cache, try fetching from Redis
-                redis_result = await self.redis_cache.async_get_cache(key, **kwargs)
+                redis_result = await self.redis_cache.async_get_cache(
+                    key, parent_otel_span=parent_otel_span
+                )
 
                 if redis_result is not None:
                     # Update in-memory cache with the value from Redis
@@ -181,7 +212,11 @@ class DualCache(BaseCache):
             verbose_logger.error(traceback.format_exc())
 
     async def async_batch_get_cache(
-        self, keys: list, local_only: bool = False, **kwargs
+        self,
+        keys: list,
+        parent_otel_span: Optional[Span] = None,
+        local_only: bool = False,
+        **kwargs,
     ):
         try:
             result = [None for _ in range(len(keys))]
@@ -202,7 +237,7 @@ class DualCache(BaseCache):
                 ]
                 # If not found in in-memory cache, try fetching from Redis
                 redis_result = await self.redis_cache.async_batch_get_cache(
-                    sublist_keys, **kwargs
+                    sublist_keys, parent_otel_span=parent_otel_span
                 )
 
                 if redis_result is not None:
@@ -260,7 +295,12 @@ class DualCache(BaseCache):
             )
 
     async def async_increment_cache(
-        self, key, value: float, local_only: bool = False, **kwargs
+        self,
+        key,
+        value: float,
+        parent_otel_span: Optional[Span],
+        local_only: bool = False,
+        **kwargs,
     ) -> float:
         """
         Key - the key in cache
@@ -277,7 +317,12 @@ class DualCache(BaseCache):
             )
 
             if self.redis_cache is not None and local_only is False:
-                result = await self.redis_cache.async_increment(key, value, **kwargs)
+                result = await self.redis_cache.async_increment(
+                    key,
+                    value,
+                    parent_otel_span=parent_otel_span,
+                    ttl=kwargs.get("ttl", None),
+                )
 
             return result
         except Exception as e:
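
With the signature changes above, callers that bump usage counters now hand the span and TTL to DualCache explicitly. A hedged usage sketch — the function, cache instance, and key format are placeholders; the keyword arguments mirror the new async_increment_cache signature:

# `router_cache` stands in for an existing DualCache instance and
# `parent_otel_span` for the span recovered from the incoming request.
async def record_tokens(router_cache, parent_otel_span, deployment_id: str, total_tokens: float) -> float:
    return await router_cache.async_increment_cache(
        key=f"{deployment_id}:tpm",         # placeholder key format
        value=total_tokens,
        parent_otel_span=parent_otel_span,  # required by the new signature
        ttl=60,                             # forwarded to Redis so the counter expires
    )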

View file

@@ -13,6 +13,7 @@ import asyncio
 import inspect
 import json
 import time
+import traceback
 from datetime import timedelta
 from typing import TYPE_CHECKING, Any, List, Optional, Tuple
@@ -25,14 +26,17 @@ from litellm.types.utils import all_litellm_params
 from .base_cache import BaseCache
 
 if TYPE_CHECKING:
+    from opentelemetry.trace import Span as _Span
     from redis.asyncio import Redis
     from redis.asyncio.client import Pipeline
 
     pipeline = Pipeline
     async_redis_client = Redis
+    Span = _Span
 else:
     pipeline = Any
     async_redis_client = Any
+    Span = Any
 
 
 class RedisCache(BaseCache):
@@ -524,7 +528,11 @@ class RedisCache(BaseCache):
             await self.flush_cache_buffer()  # logging done in here
 
     async def async_increment(
-        self, key, value: float, ttl: Optional[int] = None, **kwargs
+        self,
+        key,
+        value: float,
+        ttl: Optional[int] = None,
+        parent_otel_span: Optional[Span] = None,
     ) -> float:
         from redis.asyncio import Redis
@@ -552,7 +560,7 @@ class RedisCache(BaseCache):
                     call_type="async_increment",
                     start_time=start_time,
                     end_time=end_time,
-                    parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
+                    parent_otel_span=parent_otel_span,
                 )
             )
             return result
@@ -568,7 +576,7 @@ class RedisCache(BaseCache):
                     call_type="async_increment",
                     start_time=start_time,
                     end_time=end_time,
-                    parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
+                    parent_otel_span=parent_otel_span,
                 )
             )
             verbose_logger.error(
@@ -601,7 +609,7 @@ class RedisCache(BaseCache):
             cached_response = ast.literal_eval(cached_response)
         return cached_response
 
-    def get_cache(self, key, **kwargs):
+    def get_cache(self, key, parent_otel_span: Optional[Span] = None, **kwargs):
         try:
             key = self.check_and_fix_namespace(key=key)
             print_verbose(f"Get Redis Cache: key: {key}")
@@ -615,6 +623,7 @@ class RedisCache(BaseCache):
                 call_type="get_cache",
                 start_time=start_time,
                 end_time=end_time,
+                parent_otel_span=parent_otel_span,
             )
             print_verbose(
                 f"Got Redis Cache: key: {key}, cached_response {cached_response}"
@@ -626,11 +635,12 @@ class RedisCache(BaseCache):
                 "litellm.caching.caching: get() - Got exception from REDIS: ", e
             )
 
-    def batch_get_cache(self, key_list) -> dict:
+    def batch_get_cache(self, key_list, parent_otel_span: Optional[Span]) -> dict:
         """
         Use Redis for bulk read operations
         """
         key_value_dict = {}
         try:
             _keys = []
             for cache_key in key_list:
@@ -646,6 +656,7 @@ class RedisCache(BaseCache):
                 call_type="batch_get_cache",
                 start_time=start_time,
                 end_time=end_time,
+                parent_otel_span=parent_otel_span,
             )
 
             # Associate the results back with their keys.
@@ -662,7 +673,9 @@ class RedisCache(BaseCache):
             print_verbose(f"Error occurred in pipeline read - {str(e)}")
             return key_value_dict
 
-    async def async_get_cache(self, key, **kwargs):
+    async def async_get_cache(
+        self, key, parent_otel_span: Optional[Span] = None, **kwargs
+    ):
         from redis.asyncio import Redis
 
         _redis_client: Redis = self.init_async_client()  # type: ignore
@@ -686,7 +699,7 @@ class RedisCache(BaseCache):
                     call_type="async_get_cache",
                     start_time=start_time,
                     end_time=end_time,
-                    parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
+                    parent_otel_span=parent_otel_span,
                     event_metadata={"key": key},
                 )
             )
@@ -703,7 +716,7 @@ class RedisCache(BaseCache):
                     call_type="async_get_cache",
                     start_time=start_time,
                     end_time=end_time,
-                    parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
+                    parent_otel_span=parent_otel_span,
                     event_metadata={"key": key},
                 )
             )
@@ -712,10 +725,13 @@ class RedisCache(BaseCache):
                 f"litellm.caching.caching: async get() - Got exception from REDIS: {str(e)}"
             )
 
-    async def async_batch_get_cache(self, key_list) -> dict:
+    async def async_batch_get_cache(
+        self, key_list: List[str], parent_otel_span: Optional[Span] = None
+    ) -> dict:
         """
         Use Redis for bulk read operations
         """
         _redis_client = await self.init_async_client()
         key_value_dict = {}
         start_time = time.time()
@@ -737,6 +753,7 @@ class RedisCache(BaseCache):
                     call_type="async_batch_get_cache",
                     start_time=start_time,
                     end_time=end_time,
+                    parent_otel_span=parent_otel_span,
                 )
             )
@@ -764,6 +781,7 @@ class RedisCache(BaseCache):
                     call_type="async_batch_get_cache",
                     start_time=start_time,
                     end_time=end_time,
+                    parent_otel_span=parent_otel_span,
                 )
             )
             print_verbose(f"Error occurred in pipeline read - {str(e)}")

View file

@@ -268,6 +268,7 @@ class SlackAlerting(CustomBatchLogger):
                     SlackAlertingCacheKeys.failed_requests_key.value,
                 ),
                 value=1,
+                parent_otel_span=None,  # no attached request, this is a background operation
             )
 
             return_val += 1
@@ -279,6 +280,7 @@ class SlackAlerting(CustomBatchLogger):
                     deployment_metrics.id, SlackAlertingCacheKeys.latency_key.value
                 ),
                 value=deployment_metrics.latency_per_output_token,
+                parent_otel_span=None,  # no attached request, this is a background operation
             )
 
             return_val += 1
@@ -1518,7 +1520,8 @@ Model Info:
 
         report_sent_bool = False
         report_sent = await self.internal_usage_cache.async_get_cache(
-            key=SlackAlertingCacheKeys.report_sent_key.value
+            key=SlackAlertingCacheKeys.report_sent_key.value,
+            parent_otel_span=None,
        )  # None | float
 
        current_time = time.time()

View file

@@ -3,7 +3,7 @@
 import os
 import traceback
 from datetime import datetime as datetimeObj
-from typing import Any, Literal, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Literal, Optional, Tuple, Union
 
 import dotenv
 from pydantic import BaseModel
@@ -21,6 +21,13 @@ from litellm.types.utils import (
     StandardLoggingPayload,
 )
 
+if TYPE_CHECKING:
+    from opentelemetry.trace import Span as _Span
+
+    Span = _Span
+else:
+    Span = Any
+
 
 class CustomLogger:  # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
     # Class variables or attributes
@@ -62,7 +69,9 @@ class CustomLogger:  # https://docs.litellm.ai/docs/observability/custom_callbac
     Allows usage-based-routing-v2 to run pre-call rpm checks within the picked deployment's semaphore (concurrency-safe tpm/rpm checks).
     """
 
-    async def async_pre_call_check(self, deployment: dict) -> Optional[dict]:
+    async def async_pre_call_check(
+        self, deployment: dict, parent_otel_span: Optional[Span]
+    ) -> Optional[dict]:
         pass
 
     def pre_call_check(self, deployment: dict) -> Optional[dict]:
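
Since the base-class hook now receives the span, custom callbacks that override async_pre_call_check need the extra parameter. A hedged sketch of an updated subclass — the class name and the no-op body are placeholders, and it assumes the module-level Span alias added above is importable at runtime:

from typing import Optional

from litellm.integrations.custom_logger import CustomLogger, Span


class MyPreCallCheck(CustomLogger):  # placeholder subclass name
    async def async_pre_call_check(
        self, deployment: dict, parent_otel_span: Optional[Span]
    ) -> Optional[dict]:
        # Placeholder logic: any cache reads done here can pass the span along
        # so they attach to the request's trace.
        return deployment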

View file

@@ -8,7 +8,12 @@ import litellm
 from litellm._logging import verbose_logger
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.types.services import ServiceLoggerPayload
-from litellm.types.utils import StandardLoggingPayload
+from litellm.types.utils import (
+    EmbeddingResponse,
+    ImageResponse,
+    ModelResponse,
+    StandardLoggingPayload,
+)
 
 if TYPE_CHECKING:
     from opentelemetry.trace import Span as _Span
@@ -136,12 +141,12 @@ class OpenTelemetry(CustomLogger):
         _end_time_ns = 0
 
         if isinstance(start_time, float):
-            _start_time_ns = int(int(start_time) * 1e9)
+            _start_time_ns = int(start_time * 1e9)
         else:
             _start_time_ns = self._to_ns(start_time)
 
         if isinstance(end_time, float):
-            _end_time_ns = int(int(end_time) * 1e9)
+            _end_time_ns = int(end_time * 1e9)
         else:
             _end_time_ns = self._to_ns(end_time)
@@ -276,6 +281,21 @@ class OpenTelemetry(CustomLogger):
             # End Parent OTEL Sspan
             parent_otel_span.end(end_time=self._to_ns(datetime.now()))
 
+    async def async_post_call_success_hook(
+        self,
+        data: dict,
+        user_api_key_dict: UserAPIKeyAuth,
+        response: Union[Any, ModelResponse, EmbeddingResponse, ImageResponse],
+    ):
+        from opentelemetry import trace
+        from opentelemetry.trace import Status, StatusCode
+
+        parent_otel_span = user_api_key_dict.parent_otel_span
+        if parent_otel_span is not None:
+            parent_otel_span.set_status(Status(StatusCode.OK))
+            # End Parent OTEL Sspan
+            parent_otel_span.end(end_time=self._to_ns(datetime.now()))
+
     def _handle_sucess(self, kwargs, response_obj, start_time, end_time):
         from opentelemetry import trace
         from opentelemetry.trace import Status, StatusCode
@@ -314,8 +334,8 @@ class OpenTelemetry(CustomLogger):
 
         span.end(end_time=self._to_ns(end_time))
 
-        if parent_otel_span is not None:
-            parent_otel_span.end(end_time=self._to_ns(datetime.now()))
+        # if parent_otel_span is not None:
+        #     parent_otel_span.end(end_time=self._to_ns(datetime.now()))
 
     def _handle_failure(self, kwargs, response_obj, start_time, end_time):
         from opentelemetry.trace import Status, StatusCode
@@ -808,12 +828,12 @@ class OpenTelemetry(CustomLogger):
         end_time = logging_payload.end_time
 
         if isinstance(start_time, float):
-            _start_time_ns = int(int(start_time) * 1e9)
+            _start_time_ns = int(start_time * 1e9)
         else:
             _start_time_ns = self._to_ns(start_time)
 
         if isinstance(end_time, float):
-            _end_time_ns = int(int(end_time) * 1e9)
+            _end_time_ns = int(end_time * 1e9)
         else:
             _end_time_ns = self._to_ns(end_time)
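
The timestamp change above also fixes a precision bug: int(int(start_time) * 1e9) truncates the float to whole seconds before scaling, so every span timestamp built from a float lost its sub-second component. A small illustration with an arbitrary example value:

start_time = 1700000000.5  # float unix timestamp, half a second into the second

old_ns = int(int(start_time) * 1e9)  # 1700000000000000000 -> the .5s is discarded
new_ns = int(start_time * 1e9)       # 1700000000500000000 -> sub-second precision kept

print(new_ns - old_ns)  # 500000000, the half second the old conversion dropped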

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@@ -25,4 +25,4 @@ router_settings:
     ttl: 300
   redis_host: os.environ/REDIS_HOST
   redis_port: os.environ/REDIS_PORT
-  redis_password: os.environ/REDIS_PASSWORD
+  redis_password: os.environ/REDIS_PASSWORD

View file

@@ -260,6 +260,7 @@ async def user_api_key_auth(  # noqa: PLR0915
                 headers=request.headers
             ),
         )
+
         ### USER-DEFINED AUTH FUNCTION ###
         if user_custom_auth is not None:
             response = await user_custom_auth(request=request, api_key=api_key)  # type: ignore

View file

@@ -28,7 +28,9 @@ class _PROXY_MaxBudgetLimiter(CustomLogger):
         try:
             self.print_verbose("Inside Max Budget Limiter Pre-Call Hook")
             cache_key = f"{user_api_key_dict.user_id}_user_api_key_user_id"
-            user_row = cache.get_cache(cache_key)
+            user_row = await cache.async_get_cache(
+                cache_key, parent_otel_span=user_api_key_dict.parent_otel_span
+            )
             if user_row is None:  # value not yet cached
                 return
             max_budget = user_row["max_budget"]

View file

@@ -235,7 +235,7 @@ class InternalUsageCache:
         return await self.dual_cache.async_get_cache(
             key=key,
             local_only=local_only,
-            litellm_parent_otel_span=litellm_parent_otel_span,
+            parent_otel_span=litellm_parent_otel_span,
             **kwargs,
         )
@@ -281,7+281,7 @@ class InternalUsageCache:
             key=key,
             value=value,
             local_only=local_only,
-            litellm_parent_otel_span=litellm_parent_otel_span,
+            parent_otel_span=litellm_parent_otel_span,
             **kwargs,
         )
@@ -367,7 +367,10 @@ class ProxyLogging:
                 llm_router=llm_router
             )  # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made
 
-        if "daily_reports" in self.slack_alerting_instance.alert_types:
+        if (
+            self.slack_alerting_instance is not None
+            and "daily_reports" in self.slack_alerting_instance.alert_types
+        ):
             asyncio.create_task(
                 self.slack_alerting_instance._run_scheduled_daily_report(
                     llm_router=llm_router

View file

@@ -25,6 +25,7 @@ import uuid
 from collections import defaultdict
 from datetime import datetime
 from typing import (
+    TYPE_CHECKING,
     Any,
     Callable,
     Dict,
@@ -50,6 +51,7 @@ from litellm._logging import verbose_router_logger
 from litellm.assistants.main import AssistantDeleted
 from litellm.caching.caching import DualCache, InMemoryCache, RedisCache
 from litellm.integrations.custom_logger import CustomLogger
+from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
 from litellm.llms.AzureOpenAI.azure import get_azure_ad_token_from_oidc
 from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
@@ -124,6 +126,7 @@ from litellm.types.router import (
     updateDeployment,
     updateLiteLLMParams,
 )
+from litellm.types.services import ServiceLoggerPayload, ServiceTypes
 from litellm.types.utils import OPENAI_RESPONSE_HEADERS
 from litellm.types.utils import ModelInfo as ModelMapInfo
 from litellm.utils import (
@@ -140,6 +143,13 @@ from litellm.utils import (
 
 from .router_utils.pattern_match_deployments import PatternMatchRouter
 
+if TYPE_CHECKING:
+    from opentelemetry.trace import Span as _Span
+
+    Span = _Span
+else:
+    Span = Any
+
 
 class RoutingArgs(enum.Enum):
     ttl = 60  # 1min (RPM/TPM expire key)
@@ -293,6 +303,8 @@ class Router:
         ```
         """
+        from litellm._service_logger import ServiceLogging
+
         if semaphore:
             self.semaphore = semaphore
         self.set_verbose = set_verbose
@@ -494,7 +506,7 @@ class Router:
             f"Routing context window fallbacks: {self.context_window_fallbacks}\n\n"
             f"Router Redis Caching={self.cache.redis_cache}\n"
         )
+        self.service_logger_obj = ServiceLogging()
         self.routing_strategy_args = routing_strategy_args
         self.retry_policy: Optional[RetryPolicy] = None
         if retry_policy is not None:
@@ -762,10 +774,23 @@ class Router:
             request_priority = kwargs.get("priority") or self.default_priority
 
+            start_time = time.time()
             if request_priority is not None and isinstance(request_priority, int):
                 response = await self.schedule_acompletion(**kwargs)
             else:
                 response = await self.async_function_with_fallbacks(**kwargs)
+            end_time = time.time()
+            _duration = end_time - start_time
+            asyncio.create_task(
+                self.service_logger_obj.async_service_success_hook(
+                    service=ServiceTypes.ROUTER,
+                    duration=_duration,
+                    call_type="acompletion",
+                    start_time=start_time,
+                    end_time=end_time,
+                    parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
+                )
+            )
 
             return response
         except Exception as e:
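
The hunk above wraps the routed call with wall-clock timing and reports it without blocking the response path. A condensed sketch of that pattern, assuming a service logger object exposing the same async_service_success_hook used here (the wrapper itself and the string service name are illustrative; the real code passes ServiceTypes.ROUTER):

import asyncio
import time


async def timed_router_call(operation, service_logger_obj, call_type: str, parent_otel_span=None):
    """Await an operation, then log its duration as a fire-and-forget task."""
    start_time = time.time()
    result = await operation()  # e.g. the fallback-wrapped completion call
    end_time = time.time()
    asyncio.create_task(  # logging never delays returning the response
        service_logger_obj.async_service_success_hook(
            service="router",  # the real code passes the ServiceTypes.ROUTER enum
            duration=end_time - start_time,
            call_type=call_type,
            start_time=start_time,
            end_time=end_time,
            parent_otel_span=parent_otel_span,
        )
    )
    return result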
@@ -793,15 +818,32 @@ class Router:
             verbose_router_logger.debug(
                 f"Inside _acompletion()- model: {model}; kwargs: {kwargs}"
             )
+            parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
+            start_time = time.time()
             deployment = await self.async_get_available_deployment(
                 model=model,
                 messages=messages,
                 specific_deployment=kwargs.pop("specific_deployment", None),
                 request_kwargs=kwargs,
             )
+            end_time = time.time()
+            _duration = end_time - start_time
+            asyncio.create_task(
+                self.service_logger_obj.async_service_success_hook(
+                    service=ServiceTypes.ROUTER,
+                    duration=_duration,
+                    call_type="async_get_available_deployment",
+                    start_time=start_time,
+                    end_time=end_time,
+                    parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
+                )
+            )
 
             # debug how often this deployment picked
-            self._track_deployment_metrics(deployment=deployment)
+            self._track_deployment_metrics(
+                deployment=deployment, parent_otel_span=parent_otel_span
+            )
 
             self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs)
 
             data = deployment["litellm_params"].copy()
@@ -846,12 +888,16 @@ class Router:
                     - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
                     """
                     await self.async_routing_strategy_pre_call_checks(
-                        deployment=deployment, logging_obj=logging_obj
+                        deployment=deployment,
+                        logging_obj=logging_obj,
+                        parent_otel_span=parent_otel_span,
                     )
                     response = await _response
             else:
                 await self.async_routing_strategy_pre_call_checks(
-                    deployment=deployment, logging_obj=logging_obj
+                    deployment=deployment,
+                    logging_obj=logging_obj,
+                    parent_otel_span=parent_otel_span,
                 )
                 response = await _response
@@ -872,7 +918,11 @@ class Router:
                 f"litellm.acompletion(model={model_name})\033[32m 200 OK\033[0m"
             )
             # debug how often this deployment picked
-            self._track_deployment_metrics(deployment=deployment, response=response)
+            self._track_deployment_metrics(
+                deployment=deployment,
+                response=response,
+                parent_otel_span=parent_otel_span,
+            )
 
             return response
         except Exception as e:
@@ -1212,6 +1262,7 @@ class Router:
         stream=False,
         **kwargs,
     ):
+        parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
         ### FLOW ITEM ###
         _request_id = str(uuid.uuid4())
         item = FlowItem(
@@ -1232,7 +1283,7 @@ class Router:
         while curr_time < end_time:
             _healthy_deployments, _ = await self._async_get_healthy_deployments(
-                model=model
+                model=model, parent_otel_span=parent_otel_span
             )
             make_request = await self.scheduler.poll(  ## POLL QUEUE ## - returns 'True' if there's healthy deployments OR if request is at top of queue
                 id=item.request_id,
@@ -1353,6 +1404,7 @@ class Router:
             verbose_router_logger.debug(
                 f"Inside _image_generation()- model: {model}; kwargs: {kwargs}"
             )
+            parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
             deployment = await self.async_get_available_deployment(
                 model=model,
                 messages=[{"role": "user", "content": "prompt"}],
@@ -1395,11 +1447,13 @@ class Router:
                     - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
                     """
                     await self.async_routing_strategy_pre_call_checks(
-                        deployment=deployment
+                        deployment=deployment, parent_otel_span=parent_otel_span
                     )
                     response = await response
             else:
-                await self.async_routing_strategy_pre_call_checks(deployment=deployment)
+                await self.async_routing_strategy_pre_call_checks(
+                    deployment=deployment, parent_otel_span=parent_otel_span
+                )
                 response = await response
 
             self.success_calls[model_name] += 1
@@ -1465,6 +1519,7 @@ class Router:
             verbose_router_logger.debug(
                 f"Inside _atranscription()- model: {model}; kwargs: {kwargs}"
             )
+            parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
             deployment = await self.async_get_available_deployment(
                 model=model,
                 messages=[{"role": "user", "content": "prompt"}],
@@ -1505,11 +1560,13 @@ class Router:
                     - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
                     """
                     await self.async_routing_strategy_pre_call_checks(
-                        deployment=deployment
+                        deployment=deployment, parent_otel_span=parent_otel_span
                     )
                     response = await response
             else:
-                await self.async_routing_strategy_pre_call_checks(deployment=deployment)
+                await self.async_routing_strategy_pre_call_checks(
+                    deployment=deployment, parent_otel_span=parent_otel_span
+                )
                 response = await response
 
             self.success_calls[model_name] += 1
@@ -1861,6 +1918,7 @@ class Router:
             verbose_router_logger.debug(
                 f"Inside _atext_completion()- model: {model}; kwargs: {kwargs}"
             )
+            parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
             deployment = await self.async_get_available_deployment(
                 model=model,
                 messages=[{"role": "user", "content": prompt}],
@@ -1903,11 +1961,13 @@ class Router:
                     - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
                     """
                     await self.async_routing_strategy_pre_call_checks(
-                        deployment=deployment
+                        deployment=deployment, parent_otel_span=parent_otel_span
                     )
                     response = await response
             else:
-                await self.async_routing_strategy_pre_call_checks(deployment=deployment)
+                await self.async_routing_strategy_pre_call_checks(
+                    deployment=deployment, parent_otel_span=parent_otel_span
+                )
                 response = await response
 
             self.success_calls[model_name] += 1
@@ -1958,6 +2018,7 @@ class Router:
             verbose_router_logger.debug(
                 f"Inside _aadapter_completion()- model: {model}; kwargs: {kwargs}"
             )
+            parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
             deployment = await self.async_get_available_deployment(
                 model=model,
                 messages=[{"role": "user", "content": "default text"}],
@@ -2000,11 +2061,13 @@ class Router:
                     - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
                     """
                     await self.async_routing_strategy_pre_call_checks(
-                        deployment=deployment
+                        deployment=deployment, parent_otel_span=parent_otel_span
                     )
                     response = await response  # type: ignore
             else:
-                await self.async_routing_strategy_pre_call_checks(deployment=deployment)
+                await self.async_routing_strategy_pre_call_checks(
+                    deployment=deployment, parent_otel_span=parent_otel_span
+                )
                 response = await response  # type: ignore
 
             self.success_calls[model_name] += 1
@@ -2128,6 +2191,7 @@ class Router:
             verbose_router_logger.debug(
                 f"Inside _aembedding()- model: {model}; kwargs: {kwargs}"
             )
+            parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
             deployment = await self.async_get_available_deployment(
                 model=model,
                 input=input,
@@ -2168,11 +2232,13 @@ class Router:
                     - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
                     """
                     await self.async_routing_strategy_pre_call_checks(
-                        deployment=deployment
+                        deployment=deployment, parent_otel_span=parent_otel_span
                     )
                     response = await response
             else:
-                await self.async_routing_strategy_pre_call_checks(deployment=deployment)
+                await self.async_routing_strategy_pre_call_checks(
+                    deployment=deployment, parent_otel_span=parent_otel_span
+                )
                 response = await response
 
             self.success_calls[model_name] += 1
@@ -2223,6 +2289,7 @@ class Router:
             verbose_router_logger.debug(
                 f"Inside _atext_completion()- model: {model}; kwargs: {kwargs}"
             )
+            parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
             deployment = await self.async_get_available_deployment(
                 model=model,
                 messages=[{"role": "user", "content": "files-api-fake-text"}],
@@ -2273,11 +2340,13 @@ class Router:
                     - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
                     """
                     await self.async_routing_strategy_pre_call_checks(
-                        deployment=deployment
+                        deployment=deployment, parent_otel_span=parent_otel_span
                     )
                     response = await response  # type: ignore
             else:
-                await self.async_routing_strategy_pre_call_checks(deployment=deployment)
+                await self.async_routing_strategy_pre_call_checks(
+                    deployment=deployment, parent_otel_span=parent_otel_span
+                )
                 response = await response  # type: ignore
 
             self.success_calls[model_name] += 1
@@ -2327,6 +2396,7 @@ class Router:
             verbose_router_logger.debug(
                 f"Inside _acreate_batch()- model: {model}; kwargs: {kwargs}"
             )
+            parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
             deployment = await self.async_get_available_deployment(
                 model=model,
                 messages=[{"role": "user", "content": "files-api-fake-text"}],
@@ -2389,11 +2459,13 @@ class Router:
                     - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
                     """
                     await self.async_routing_strategy_pre_call_checks(
-                        deployment=deployment
+                        deployment=deployment, parent_otel_span=parent_otel_span
                     )
                     response = await response  # type: ignore
             else:
-                await self.async_routing_strategy_pre_call_checks(deployment=deployment)
+                await self.async_routing_strategy_pre_call_checks(
+                    deployment=deployment, parent_otel_span=parent_otel_span
+                )
                 response = await response  # type: ignore
 
             self.success_calls[model_name] += 1
@@ -2702,12 +2774,14 @@ class Router:
                 )
                 return response
             except Exception as new_exception:
+                parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
                 verbose_router_logger.error(
                     "litellm.router.py::async_function_with_fallbacks() - Error occurred while trying to do fallbacks - {}\n{}\n\nDebug Information:\nCooldown Deployments={}".format(
                         str(new_exception),
                         traceback.format_exc(),
                         await _async_get_cooldown_deployments_with_debug_info(
-                            litellm_router_instance=self
+                            litellm_router_instance=self,
+                            parent_otel_span=parent_otel_span,
                         ),
                     )
                 )
@@ -2779,12 +2853,13 @@ class Router:
                     Context_Policy_Fallbacks={content_policy_fallbacks}",
             )
 
-    async def async_function_with_retries(self, *args, **kwargs):
+    async def async_function_with_retries(self, *args, **kwargs):  # noqa: PLR0915
         verbose_router_logger.debug(
             f"Inside async function with retries: args - {args}; kwargs - {kwargs}"
         )
         original_function = kwargs.pop("original_function")
         fallbacks = kwargs.pop("fallbacks", self.fallbacks)
+        parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
         context_window_fallbacks = kwargs.pop(
             "context_window_fallbacks", self.context_window_fallbacks
         )
@@ -2822,6 +2897,7 @@ class Router:
                 _healthy_deployments, _all_deployments = (
                     await self._async_get_healthy_deployments(
                         model=kwargs.get("model") or "",
+                        parent_otel_span=parent_otel_span,
                     )
                 )
@@ -2879,6 +2955,7 @@ class Router:
                         _healthy_deployments, _ = (
                             await self._async_get_healthy_deployments(
                                 model=_model,
+                                parent_otel_span=parent_otel_span,
                             )
                         )
                     else:
@@ -3217,8 +3294,10 @@ class Router:
             if _model is None:
                 raise e  # re-raise error, if model can't be determined for loadbalancing
             ### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR
+            parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
             _healthy_deployments, _all_deployments = self._get_healthy_deployments(
                 model=_model,
+                parent_otel_span=parent_otel_span,
             )
 
             # raises an exception if this error should not be retries
@@ -3260,8 +3339,10 @@ class Router:
             if _model is None:
                 raise e  # re-raise error, if model can't be determined for loadbalancing
 
+            parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
             _healthy_deployments, _ = self._get_healthy_deployments(
                 model=_model,
+                parent_otel_span=parent_otel_span,
             )
             remaining_retries = num_retries - current_attempt
             _timeout = self._time_to_sleep_before_retry(
@@ -3323,9 +3404,13 @@ class Router:
             # ------------
             # update cache
 
+            parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
             ## TPM
             await self.cache.async_increment_cache(
-                key=tpm_key, value=total_tokens, ttl=RoutingArgs.ttl.value
+                key=tpm_key,
+                value=total_tokens,
+                parent_otel_span=parent_otel_span,
+                ttl=RoutingArgs.ttl.value,
             )
 
             increment_deployment_successes_for_current_minute(
@@ -3474,7 +3559,9 @@ class Router:
         except Exception as e:
             raise e
 
-    def _update_usage(self, deployment_id: str) -> int:
+    def _update_usage(
+        self, deployment_id: str, parent_otel_span: Optional[Span]
+    ) -> int:
         """
         Update deployment rpm for that minute
@@ -3483,7 +3570,9 @@ class Router:
         """
         rpm_key = deployment_id
 
-        request_count = self.cache.get_cache(key=rpm_key, local_only=True)
+        request_count = self.cache.get_cache(
+            key=rpm_key, parent_otel_span=parent_otel_span, local_only=True
+        )
         if request_count is None:
             request_count = 1
             self.cache.set_cache(
@@ -3591,7 +3680,7 @@ class Router:
             )
             return False
 
-    def _get_healthy_deployments(self, model: str):
+    def _get_healthy_deployments(self, model: str, parent_otel_span: Optional[Span]):
         _all_deployments: list = []
         try:
             _, _all_deployments = self._common_checks_available_deployment(  # type: ignore
@@ -3602,7 +3691,9 @@ class Router:
         except Exception:
             pass
 
-        unhealthy_deployments = _get_cooldown_deployments(litellm_router_instance=self)
+        unhealthy_deployments = _get_cooldown_deployments(
+            litellm_router_instance=self, parent_otel_span=parent_otel_span
+        )
         healthy_deployments: list = []
         for deployment in _all_deployments:
             if deployment["model_info"]["id"] in unhealthy_deployments:
@@ -3613,7 +3704,7 @@ class Router:
         return healthy_deployments, _all_deployments
 
     async def _async_get_healthy_deployments(
-        self, model: str
+        self, model: str, parent_otel_span: Optional[Span]
     ) -> Tuple[List[Dict], List[Dict]]:
         """
         Returns Tuple of:
@@ -3632,7 +3723,7 @@ class Router:
             pass
 
         unhealthy_deployments = await _async_get_cooldown_deployments(
-            litellm_router_instance=self
+            litellm_router_instance=self, parent_otel_span=parent_otel_span
        )
         healthy_deployments: list = []
         for deployment in _all_deployments:
@@ -3659,7 +3750,10 @@ class Router:
                     _callback.pre_call_check(deployment)
 
     async def async_routing_strategy_pre_call_checks(
-        self, deployment: dict, logging_obj: Optional[LiteLLMLogging] = None
+        self,
+        deployment: dict,
+        parent_otel_span: Optional[Span],
+        logging_obj: Optional[LiteLLMLogging] = None,
     ):
         """
         For usage-based-routing-v2, enables running rpm checks before the call is made, inside the semaphore.
@@ -3675,7 +3769,7 @@ class Router:
         for _callback in litellm.callbacks:
             if isinstance(_callback, CustomLogger):
                 try:
-                    await _callback.async_pre_call_check(deployment)
+                    await _callback.async_pre_call_check(deployment, parent_otel_span)
                 except litellm.RateLimitError as e:
                     ## LOG FAILURE EVENT
                     if logging_obj is not None:
@@ -4646,14 +4740,19 @@ class Router:
             The appropriate client based on the given client_type and kwargs.
         """
         model_id = deployment["model_info"]["id"]
+        parent_otel_span: Optional[Span] = _get_parent_otel_span_from_kwargs(kwargs)
         if client_type == "max_parallel_requests":
             cache_key = "{}_max_parallel_requests_client".format(model_id)
-            client = self.cache.get_cache(key=cache_key, local_only=True)
+            client = self.cache.get_cache(
+                key=cache_key, local_only=True, parent_otel_span=parent_otel_span
+            )
             return client
         elif client_type == "async":
             if kwargs.get("stream") is True:
                 cache_key = f"{model_id}_stream_async_client"
-                client = self.cache.get_cache(key=cache_key, local_only=True)
+                client = self.cache.get_cache(
+                    key=cache_key, local_only=True, parent_otel_span=parent_otel_span
+                )
                 if client is None:
                     """
                     Re-initialize the client
@@ -4661,11 +4760,17 @@ class Router:
                     InitalizeOpenAISDKClient.set_client(
                         litellm_router_instance=self, model=deployment
                     )
-                    client = self.cache.get_cache(key=cache_key, local_only=True)
+                    client = self.cache.get_cache(
+                        key=cache_key,
+                        local_only=True,
+                        parent_otel_span=parent_otel_span,
+                    )
                 return client
             else:
                 cache_key = f"{model_id}_async_client"
-                client = self.cache.get_cache(key=cache_key, local_only=True)
+                client = self.cache.get_cache(
+                    key=cache_key, local_only=True, parent_otel_span=parent_otel_span
+                )
                 if client is None:
                     """
                     Re-initialize the client
@@ -4673,12 +4778,18 @@ class Router:
                     InitalizeOpenAISDKClient.set_client(
                         litellm_router_instance=self, model=deployment
                     )
-                    client = self.cache.get_cache(key=cache_key, local_only=True)
+                    client = self.cache.get_cache(
+                        key=cache_key,
+                        local_only=True,
+                        parent_otel_span=parent_otel_span,
+                    )
                 return client
         else:
             if kwargs.get("stream") is True:
                 cache_key = f"{model_id}_stream_client"
-                client = self.cache.get_cache(key=cache_key)
+                client = self.cache.get_cache(
+                    key=cache_key, parent_otel_span=parent_otel_span
+                )
                 if client is None:
                     """
                     Re-initialize the client
@@ -4686,11 +4797,15 @@ class Router:
                     InitalizeOpenAISDKClient.set_client(
                         litellm_router_instance=self, model=deployment
                     )
-                    client = self.cache.get_cache(key=cache_key)
+                    client = self.cache.get_cache(
+                        key=cache_key, parent_otel_span=parent_otel_span
+                    )
                 return client
            else:
                 cache_key = f"{model_id}_client"
-                client = self.cache.get_cache(key=cache_key)
+                client = self.cache.get_cache(
+                    key=cache_key, parent_otel_span=parent_otel_span
+                )
                 if client is None:
                     """
                     Re-initialize the client
@@ -4698,7 +4813,9 @@ class Router:
                     InitalizeOpenAISDKClient.set_client(
                         litellm_router_instance=self, model=deployment
                     )
-                    client = self.cache.get_cache(key=cache_key)
+                    client = self.cache.get_cache(
+                        key=cache_key, parent_otel_span=parent_otel_span
+                    )
                 return client
 
     def _pre_call_checks(  # noqa: PLR0915
@@ -4738,13 +4855,17 @@ class Router:
         _context_window_error = False
         _potential_error_str = ""
         _rate_limit_error = False
+        parent_otel_span = _get_parent_otel_span_from_kwargs(request_kwargs)
         ## get model group RPM ##
         dt = get_utc_datetime()
         current_minute = dt.strftime("%H-%M")
         rpm_key = f"{model}:rpm:{current_minute}"
         model_group_cache = (
-            self.cache.get_cache(key=rpm_key, local_only=True) or {}
+            self.cache.get_cache(
+                key=rpm_key, local_only=True, parent_otel_span=parent_otel_span
+            )
+            or {}
         )  # check the in-memory cache used by lowest_latency and usage-based routing. Only check the local cache.
         for idx, deployment in enumerate(_returned_deployments):
             # see if we have the info for this model
@@ -4783,7 +4904,10 @@ class Router:
                 ## RPM CHECK ##
                 ### get local router cache ###
                 current_request_cache_local = (
-                    self.cache.get_cache(key=model_id, local_only=True) or 0
+                    self.cache.get_cache(
+                        key=model_id, local_only=True, parent_otel_span=parent_otel_span
+                    )
+                    or 0
                 )
                 ### get usage based cache ###
                 if (
@@ -5002,6 +5126,7 @@ class Router:
             self.routing_strategy != "usage-based-routing-v2"
             and self.routing_strategy != "simple-shuffle"
             and self.routing_strategy != "cost-based-routing"
+            and self.routing_strategy != "latency-based-routing"
         ):  # prevent regressions for other routing strategies, that don't have async get available deployments implemented.
             return self.get_available_deployment(
                 model=model,
@@ -5011,6 +5136,7 @@ class Router:
                 request_kwargs=request_kwargs,
             )
         try:
+            parent_otel_span = _get_parent_otel_span_from_kwargs(request_kwargs)
             model, healthy_deployments = self._common_checks_available_deployment(
                 model=model,
                 messages=messages,
@@ -5021,7 +5147,7 @@ class Router:
                 return healthy_deployments
 
             cooldown_deployments = await _async_get_cooldown_deployments(
-                litellm_router_instance=self
+                litellm_router_instance=self, parent_otel_span=parent_otel_span
             )
             verbose_router_logger.debug(
                 f"async cooldown deployments: {cooldown_deployments}"
@@ -5059,16 +5185,18 @@ class Router:
                         _allowed_model_region = "n/a"
                 model_ids = self.get_model_ids(model_name=model)
                 _cooldown_time = self.cooldown_cache.get_min_cooldown(
-                    model_ids=model_ids
+                    model_ids=model_ids, parent_otel_span=parent_otel_span
+                )
+                _cooldown_list = _get_cooldown_deployments(
+                    litellm_router_instance=self, parent_otel_span=parent_otel_span
                 )
-                _cooldown_list = _get_cooldown_deployments(litellm_router_instance=self)
                 raise RouterRateLimitError(
                     model=model,
                     cooldown_time=_cooldown_time,
                     enable_pre_call_checks=self.enable_pre_call_checks,
                     cooldown_list=_cooldown_list,
                 )
+            start_time = time.time()
             if (
                 self.routing_strategy == "usage-based-routing-v2"
                 and self.lowesttpm_logger_v2 is not None
@ -5093,6 +5221,19 @@ class Router:
input=input, input=input,
) )
) )
elif (
self.routing_strategy == "latency-based-routing"
and self.lowestlatency_logger is not None
):
deployment = (
await self.lowestlatency_logger.async_get_available_deployments(
model_group=model,
healthy_deployments=healthy_deployments, # type: ignore
messages=messages,
input=input,
request_kwargs=request_kwargs,
)
)
elif self.routing_strategy == "simple-shuffle": elif self.routing_strategy == "simple-shuffle":
return simple_shuffle( return simple_shuffle(
llm_router_instance=self, llm_router_instance=self,
@@ -5107,9 +5248,11 @@ class Router:
) )
model_ids = self.get_model_ids(model_name=model) model_ids = self.get_model_ids(model_name=model)
_cooldown_time = self.cooldown_cache.get_min_cooldown( _cooldown_time = self.cooldown_cache.get_min_cooldown(
model_ids=model_ids model_ids=model_ids, parent_otel_span=parent_otel_span
)
_cooldown_list = _get_cooldown_deployments(
litellm_router_instance=self, parent_otel_span=parent_otel_span
) )
_cooldown_list = _get_cooldown_deployments(litellm_router_instance=self)
raise RouterRateLimitError( raise RouterRateLimitError(
model=model, model=model,
cooldown_time=_cooldown_time, cooldown_time=_cooldown_time,
@@ -5120,6 +5263,19 @@ class Router:
f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment)} for model: {model}" f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment)} for model: {model}"
) )
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_success_hook(
service=ServiceTypes.ROUTER,
duration=_duration,
call_type="<routing_strategy>.async_get_available_deployments",
parent_otel_span=parent_otel_span,
start_time=start_time,
end_time=end_time,
)
)
return deployment return deployment
except Exception as e: except Exception as e:
traceback_exception = traceback.format_exc() traceback_exception = traceback.format_exc()
@@ -5163,7 +5319,12 @@ class Router:
if isinstance(healthy_deployments, dict): if isinstance(healthy_deployments, dict):
return healthy_deployments return healthy_deployments
cooldown_deployments = _get_cooldown_deployments(litellm_router_instance=self) parent_otel_span: Optional[Span] = _get_parent_otel_span_from_kwargs(
request_kwargs
)
cooldown_deployments = _get_cooldown_deployments(
litellm_router_instance=self, parent_otel_span=parent_otel_span
)
healthy_deployments = self._filter_cooldown_deployments( healthy_deployments = self._filter_cooldown_deployments(
healthy_deployments=healthy_deployments, healthy_deployments=healthy_deployments,
cooldown_deployments=cooldown_deployments, cooldown_deployments=cooldown_deployments,
@@ -5180,8 +5341,12 @@ class Router:
if len(healthy_deployments) == 0: if len(healthy_deployments) == 0:
model_ids = self.get_model_ids(model_name=model) model_ids = self.get_model_ids(model_name=model)
_cooldown_time = self.cooldown_cache.get_min_cooldown(model_ids=model_ids) _cooldown_time = self.cooldown_cache.get_min_cooldown(
_cooldown_list = _get_cooldown_deployments(litellm_router_instance=self) model_ids=model_ids, parent_otel_span=parent_otel_span
)
_cooldown_list = _get_cooldown_deployments(
litellm_router_instance=self, parent_otel_span=parent_otel_span
)
raise RouterRateLimitError( raise RouterRateLimitError(
model=model, model=model,
cooldown_time=_cooldown_time, cooldown_time=_cooldown_time,
@@ -5238,8 +5403,12 @@ class Router:
f"get_available_deployment for model: {model}, No deployment available" f"get_available_deployment for model: {model}, No deployment available"
) )
model_ids = self.get_model_ids(model_name=model) model_ids = self.get_model_ids(model_name=model)
_cooldown_time = self.cooldown_cache.get_min_cooldown(model_ids=model_ids) _cooldown_time = self.cooldown_cache.get_min_cooldown(
_cooldown_list = _get_cooldown_deployments(litellm_router_instance=self) model_ids=model_ids, parent_otel_span=parent_otel_span
)
_cooldown_list = _get_cooldown_deployments(
litellm_router_instance=self, parent_otel_span=parent_otel_span
)
raise RouterRateLimitError( raise RouterRateLimitError(
model=model, model=model,
cooldown_time=_cooldown_time, cooldown_time=_cooldown_time,
@@ -5278,7 +5447,9 @@ class Router:
healthy_deployments.remove(deployment) healthy_deployments.remove(deployment)
return healthy_deployments return healthy_deployments
def _track_deployment_metrics(self, deployment, response=None): def _track_deployment_metrics(
self, deployment, parent_otel_span: Optional[Span], response=None
):
""" """
Tracks successful requests rpm usage. Tracks successful requests rpm usage.
""" """
@@ -5288,7 +5459,9 @@ class Router:
# update self.deployment_stats # update self.deployment_stats
if model_id is not None: if model_id is not None:
self._update_usage(model_id) # update in-memory cache for tracking self._update_usage(
model_id, parent_otel_span
) # update in-memory cache for tracking
except Exception as e: except Exception as e:
verbose_router_logger.error(f"Error in _track_deployment_metrics: {str(e)}") verbose_router_logger.error(f"Error in _track_deployment_metrics: {str(e)}")
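
The start_time/end_time block added above times each async deployment selection and reports it as a ROUTER service event tied to the caller's OTEL span. Outside the diff, the same pattern reads roughly like the sketch below; only async_service_success_hook, its keyword arguments, and ServiceTypes.ROUTER come from this change, while the import path and the stand-in routing call are assumptions.

import asyncio
import time

from litellm.types.services import ServiceTypes  # assumed module path


async def timed_deployment_pick(router, model: str, request_kwargs: dict):
    # In the router itself the span comes from _get_parent_otel_span_from_kwargs(request_kwargs).
    parent_otel_span = None
    start_time = time.time()

    deployment = await router.async_get_available_deployment(  # stand-in for the routing call
        model=model, request_kwargs=request_kwargs
    )

    end_time = time.time()
    # Fire-and-forget so latency reporting never blocks the request path.
    asyncio.create_task(
        router.service_logger_obj.async_service_success_hook(
            service=ServiceTypes.ROUTER,
            duration=end_time - start_time,
            call_type="<routing_strategy>.async_get_available_deployments",
            parent_otel_span=parent_otel_span,
            start_time=start_time,
            end_time=end_time,
        )
    )
    return deployment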

View file

@@ -3,7 +3,7 @@
import random import random
import traceback import traceback
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Dict, List, Optional, Union from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from pydantic import BaseModel from pydantic import BaseModel
@@ -11,6 +11,14 @@ import litellm
from litellm import ModelResponse, token_counter, verbose_logger from litellm import ModelResponse, token_counter, verbose_logger
from litellm.caching.caching import DualCache from litellm.caching.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
Span = _Span
else:
Span = Any
class LiteLLMBase(BaseModel): class LiteLLMBase(BaseModel):
@@ -115,8 +123,13 @@ class LowestLatencyLoggingHandler(CustomLogger):
# ------------ # ------------
# Update usage # Update usage
# ------------ # ------------
parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
request_count_dict = self.router_cache.get_cache(key=latency_key) or {} request_count_dict = (
self.router_cache.get_cache(
key=latency_key, parent_otel_span=parent_otel_span
)
or {}
)
if id not in request_count_dict: if id not in request_count_dict:
request_count_dict[id] = {} request_count_dict[id] = {}
@@ -213,7 +226,7 @@ class LowestLatencyLoggingHandler(CustomLogger):
""" """
latency_key = f"{model_group}_map" latency_key = f"{model_group}_map"
request_count_dict = ( request_count_dict = (
self.router_cache.get_cache(key=latency_key) or {} await self.router_cache.async_get_cache(key=latency_key) or {}
) )
if id not in request_count_dict: if id not in request_count_dict:
@@ -316,8 +329,15 @@ class LowestLatencyLoggingHandler(CustomLogger):
# ------------ # ------------
# Update usage # Update usage
# ------------ # ------------
parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
request_count_dict = self.router_cache.get_cache(key=latency_key) or {} request_count_dict = (
await self.router_cache.async_get_cache(
key=latency_key,
parent_otel_span=parent_otel_span,
local_only=True,
)
or {}
)
if id not in request_count_dict: if id not in request_count_dict:
request_count_dict[id] = {} request_count_dict[id] = {}
@@ -379,26 +399,21 @@ class LowestLatencyLoggingHandler(CustomLogger):
) )
pass pass
def get_available_deployments( # noqa: PLR0915 def _get_available_deployments( # noqa: PLR0915
self, self,
model_group: str, model_group: str,
healthy_deployments: list, healthy_deployments: list,
messages: Optional[List[Dict[str, str]]] = None, messages: Optional[List[Dict[str, str]]] = None,
input: Optional[Union[str, List]] = None, input: Optional[Union[str, List]] = None,
request_kwargs: Optional[Dict] = None, request_kwargs: Optional[Dict] = None,
request_count_dict: Optional[Dict] = None,
): ):
""" """Common logic for both sync and async get_available_deployments"""
Returns a deployment with the lowest latency
"""
# get list of potential deployments
latency_key = f"{model_group}_map"
_latency_per_deployment = {}
request_count_dict = self.router_cache.get_cache(key=latency_key) or {}
# ----------------------- # -----------------------
# Find lowest used model # Find lowest used model
# ---------------------- # ----------------------
_latency_per_deployment = {}
lowest_latency = float("inf") lowest_latency = float("inf")
current_date = datetime.now().strftime("%Y-%m-%d") current_date = datetime.now().strftime("%Y-%m-%d")
@@ -428,8 +443,8 @@ class LowestLatencyLoggingHandler(CustomLogger):
# randomly sample from all_deployments, incase all deployments have latency=0.0 # randomly sample from all_deployments, incase all deployments have latency=0.0
_items = all_deployments.items() _items = all_deployments.items()
all_deployments = random.sample(list(_items), len(_items)) _all_deployments = random.sample(list(_items), len(_items))
all_deployments = dict(all_deployments) all_deployments = dict(_all_deployments)
### GET AVAILABLE DEPLOYMENTS ### filter out any deployments > tpm/rpm limits ### GET AVAILABLE DEPLOYMENTS ### filter out any deployments > tpm/rpm limits
potential_deployments = [] potential_deployments = []
@@ -525,3 +540,66 @@ class LowestLatencyLoggingHandler(CustomLogger):
"_latency_per_deployment" "_latency_per_deployment"
] = _latency_per_deployment ] = _latency_per_deployment
return deployment return deployment
async def async_get_available_deployments(
self,
model_group: str,
healthy_deployments: list,
messages: Optional[List[Dict[str, str]]] = None,
input: Optional[Union[str, List]] = None,
request_kwargs: Optional[Dict] = None,
):
# get list of potential deployments
latency_key = f"{model_group}_map"
parent_otel_span: Optional[Span] = _get_parent_otel_span_from_kwargs(
request_kwargs
)
request_count_dict = (
await self.router_cache.async_get_cache(
key=latency_key, parent_otel_span=parent_otel_span
)
or {}
)
return self._get_available_deployments(
model_group,
healthy_deployments,
messages,
input,
request_kwargs,
request_count_dict,
)
def get_available_deployments(
self,
model_group: str,
healthy_deployments: list,
messages: Optional[List[Dict[str, str]]] = None,
input: Optional[Union[str, List]] = None,
request_kwargs: Optional[Dict] = None,
):
"""
Returns a deployment with the lowest latency
"""
# get list of potential deployments
latency_key = f"{model_group}_map"
parent_otel_span: Optional[Span] = _get_parent_otel_span_from_kwargs(
request_kwargs
)
request_count_dict = (
self.router_cache.get_cache(
key=latency_key, parent_otel_span=parent_otel_span
)
or {}
)
return self._get_available_deployments(
model_group,
healthy_deployments,
messages,
input,
request_kwargs,
request_count_dict,
)
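
With the async path in place, the sync and async entry points share _get_available_deployments and differ only in how the latency map is read from the cache. A minimal usage sketch, assuming a standard Router setup; model names, keys, and API bases below are placeholders.

import asyncio

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {"model": "gpt-3.5-turbo", "api_key": "sk-..."},
        },
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "azure/gpt-35-turbo",
                "api_key": "azure-key",
                "api_base": "https://example-endpoint.openai.azure.com",
            },
        },
    ],
    routing_strategy="latency-based-routing",
)


async def main():
    # With this strategy the router now awaits async_get_available_deployments,
    # so the latency-map lookup happens on the async cache path with the parent
    # OTEL span (if any) carried in the request kwargs.
    response = await router.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hello"}],
    )
    print(response)


asyncio.run(main())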

View file

@@ -2,7 +2,7 @@
# identifies lowest tpm deployment # identifies lowest tpm deployment
import random import random
import traceback import traceback
from typing import Dict, List, Optional, Union from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
import httpx import httpx
from pydantic import BaseModel from pydantic import BaseModel
@@ -12,9 +12,17 @@ from litellm import token_counter
from litellm._logging import verbose_logger, verbose_router_logger from litellm._logging import verbose_logger, verbose_router_logger
from litellm.caching.caching import DualCache from litellm.caching.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
from litellm.types.router import RouterErrors from litellm.types.router import RouterErrors
from litellm.utils import get_utc_datetime, print_verbose from litellm.utils import get_utc_datetime, print_verbose
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
Span = _Span
else:
Span = Any
class LiteLLMBase(BaseModel): class LiteLLMBase(BaseModel):
""" """
@@ -136,7 +144,9 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
raise e raise e
return deployment # don't fail calls if eg. redis fails to connect return deployment # don't fail calls if eg. redis fails to connect
async def async_pre_call_check(self, deployment: Dict) -> Optional[Dict]: async def async_pre_call_check(
self, deployment: Dict, parent_otel_span: Optional[Span]
) -> Optional[Dict]:
""" """
Pre-call check + update model rpm Pre-call check + update model rpm
- Used inside semaphore - Used inside semaphore
@@ -192,7 +202,10 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
else: else:
# if local result below limit, check redis ## prevent unnecessary redis checks # if local result below limit, check redis ## prevent unnecessary redis checks
result = await self.router_cache.async_increment_cache( result = await self.router_cache.async_increment_cache(
key=rpm_key, value=1, ttl=self.routing_args.ttl key=rpm_key,
value=1,
ttl=self.routing_args.ttl,
parent_otel_span=parent_otel_span,
) )
if result is not None and result > deployment_rpm: if result is not None and result > deployment_rpm:
raise litellm.RateLimitError( raise litellm.RateLimitError(
@@ -301,10 +314,13 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
# Update usage # Update usage
# ------------ # ------------
# update cache # update cache
parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
## TPM ## TPM
await self.router_cache.async_increment_cache( await self.router_cache.async_increment_cache(
key=tpm_key, value=total_tokens, ttl=self.routing_args.ttl key=tpm_key,
value=total_tokens,
ttl=self.routing_args.ttl,
parent_otel_span=parent_otel_span,
) )
### TESTING ### ### TESTING ###
@@ -547,6 +563,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
healthy_deployments: list, healthy_deployments: list,
messages: Optional[List[Dict[str, str]]] = None, messages: Optional[List[Dict[str, str]]] = None,
input: Optional[Union[str, List]] = None, input: Optional[Union[str, List]] = None,
parent_otel_span: Optional[Span] = None,
): ):
""" """
Returns a deployment with the lowest TPM/RPM usage. Returns a deployment with the lowest TPM/RPM usage.
@@ -572,10 +589,10 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
rpm_keys.append(rpm_key) rpm_keys.append(rpm_key)
tpm_values = self.router_cache.batch_get_cache( tpm_values = self.router_cache.batch_get_cache(
keys=tpm_keys keys=tpm_keys, parent_otel_span=parent_otel_span
) # [1, 2, None, ..] ) # [1, 2, None, ..]
rpm_values = self.router_cache.batch_get_cache( rpm_values = self.router_cache.batch_get_cache(
keys=rpm_keys keys=rpm_keys, parent_otel_span=parent_otel_span
) # [1, 2, None, ..] ) # [1, 2, None, ..]
deployment = self._common_checks_available_deployment( deployment = self._common_checks_available_deployment(
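
The TPM/RPM handler now threads the request's parent OTEL span into every cache increment and read, so the Redis round-trips it triggers show up under the calling request's trace. A minimal sketch of that pattern, assuming the post-change DualCache.async_increment_cache signature shown above; the key and TTL are placeholders.

import asyncio

from litellm.caching.caching import DualCache
from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs


async def bump_rpm(router_cache: DualCache, rpm_key: str, request_kwargs: dict):
    # Reuse the span that was attached to the incoming request, if any.
    parent_otel_span = _get_parent_otel_span_from_kwargs(request_kwargs)
    return await router_cache.async_increment_cache(
        key=rpm_key,
        value=1,
        ttl=60,  # placeholder TTL; the handler uses self.routing_args.ttl
        parent_otel_span=parent_otel_span,
    )


# In-memory-only DualCache here; with a Redis-backed cache the span wraps the Redis call.
asyncio.run(bump_rpm(DualCache(), "gpt-3.5-turbo:rpm:00-00", request_kwargs={}))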

View file

@@ -4,11 +4,18 @@ Wrapper around router cache. Meant to handle model cooldown logic
import json import json
import time import time
from typing import List, Optional, Tuple, TypedDict from typing import TYPE_CHECKING, Any, List, Optional, Tuple, TypedDict
from litellm import verbose_logger from litellm import verbose_logger
from litellm.caching.caching import DualCache from litellm.caching.caching import DualCache
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
Span = _Span
else:
Span = Any
class CooldownCacheValue(TypedDict): class CooldownCacheValue(TypedDict):
exception_received: str exception_received: str
@@ -77,13 +84,18 @@ class CooldownCache:
raise e raise e
async def async_get_active_cooldowns( async def async_get_active_cooldowns(
self, model_ids: List[str] self, model_ids: List[str], parent_otel_span: Optional[Span]
) -> List[Tuple[str, CooldownCacheValue]]: ) -> List[Tuple[str, CooldownCacheValue]]:
# Generate the keys for the deployments # Generate the keys for the deployments
keys = [f"deployment:{model_id}:cooldown" for model_id in model_ids] keys = [f"deployment:{model_id}:cooldown" for model_id in model_ids]
# Retrieve the values for the keys using mget # Retrieve the values for the keys using mget
results = await self.cache.async_batch_get_cache(keys=keys) or [] results = (
await self.cache.async_batch_get_cache(
keys=keys, parent_otel_span=parent_otel_span
)
or []
)
active_cooldowns = [] active_cooldowns = []
# Process the results # Process the results
@@ -95,13 +107,15 @@ class CooldownCache:
return active_cooldowns return active_cooldowns
def get_active_cooldowns( def get_active_cooldowns(
self, model_ids: List[str] self, model_ids: List[str], parent_otel_span: Optional[Span]
) -> List[Tuple[str, CooldownCacheValue]]: ) -> List[Tuple[str, CooldownCacheValue]]:
# Generate the keys for the deployments # Generate the keys for the deployments
keys = [f"deployment:{model_id}:cooldown" for model_id in model_ids] keys = [f"deployment:{model_id}:cooldown" for model_id in model_ids]
# Retrieve the values for the keys using mget # Retrieve the values for the keys using mget
results = self.cache.batch_get_cache(keys=keys) or [] results = (
self.cache.batch_get_cache(keys=keys, parent_otel_span=parent_otel_span)
or []
)
active_cooldowns = [] active_cooldowns = []
# Process the results # Process the results
@@ -112,14 +126,19 @@ class CooldownCache:
return active_cooldowns return active_cooldowns
def get_min_cooldown(self, model_ids: List[str]) -> float: def get_min_cooldown(
self, model_ids: List[str], parent_otel_span: Optional[Span]
) -> float:
"""Return min cooldown time required for a group of model id's.""" """Return min cooldown time required for a group of model id's."""
# Generate the keys for the deployments # Generate the keys for the deployments
keys = [f"deployment:{model_id}:cooldown" for model_id in model_ids] keys = [f"deployment:{model_id}:cooldown" for model_id in model_ids]
# Retrieve the values for the keys using mget # Retrieve the values for the keys using mget
results = self.cache.batch_get_cache(keys=keys) or [] results = (
self.cache.batch_get_cache(keys=keys, parent_otel_span=parent_otel_span)
or []
)
min_cooldown_time: Optional[float] = None min_cooldown_time: Optional[float] = None
# Process the results # Process the results
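
CooldownCache reads now take the parent span explicitly, so a caller with no active trace simply passes None. A short sketch against an existing Router instance, using only attributes that appear in this diff (cooldown_cache, get_model_ids); everything else is assumed.

from typing import List, Tuple

from litellm import Router


def snapshot_cooldowns(router: Router) -> Tuple[float, List[str]]:
    # Read cooldown state through the span-aware API; None means "no parent trace".
    model_ids = router.get_model_ids()
    min_cooldown = router.cooldown_cache.get_min_cooldown(
        model_ids=model_ids, parent_otel_span=None
    )
    active = router.cooldown_cache.get_active_cooldowns(
        model_ids=model_ids, parent_otel_span=None
    )
    return min_cooldown, [model_id for model_id, _ in active]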

View file

@@ -20,12 +20,15 @@ from .router_callbacks.track_deployment_metrics import (
) )
if TYPE_CHECKING: if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
from litellm.router import Router as _Router from litellm.router import Router as _Router
LitellmRouter = _Router LitellmRouter = _Router
Span = _Span
else: else:
LitellmRouter = Any LitellmRouter = Any
Span = Any
DEFAULT_FAILURE_THRESHOLD_PERCENT = ( DEFAULT_FAILURE_THRESHOLD_PERCENT = (
0.5 # default cooldown a deployment if 50% of requests fail in a given minute 0.5 # default cooldown a deployment if 50% of requests fail in a given minute
) )
@@ -207,6 +210,7 @@ def _set_cooldown_deployments(
async def _async_get_cooldown_deployments( async def _async_get_cooldown_deployments(
litellm_router_instance: LitellmRouter, litellm_router_instance: LitellmRouter,
parent_otel_span: Optional[Span],
) -> List[str]: ) -> List[str]:
""" """
Async implementation of '_get_cooldown_deployments' Async implementation of '_get_cooldown_deployments'
@@ -214,7 +218,8 @@ async def _async_get_cooldown_deployments(
model_ids = litellm_router_instance.get_model_ids() model_ids = litellm_router_instance.get_model_ids()
cooldown_models = ( cooldown_models = (
await litellm_router_instance.cooldown_cache.async_get_active_cooldowns( await litellm_router_instance.cooldown_cache.async_get_active_cooldowns(
model_ids=model_ids model_ids=model_ids,
parent_otel_span=parent_otel_span,
) )
) )
@@ -233,6 +238,7 @@ async def _async_get_cooldown_deployments(
async def _async_get_cooldown_deployments_with_debug_info( async def _async_get_cooldown_deployments_with_debug_info(
litellm_router_instance: LitellmRouter, litellm_router_instance: LitellmRouter,
parent_otel_span: Optional[Span],
) -> List[tuple]: ) -> List[tuple]:
""" """
Async implementation of '_get_cooldown_deployments' Async implementation of '_get_cooldown_deployments'
@@ -240,7 +246,7 @@ async def _async_get_cooldown_deployments_with_debug_info(
model_ids = litellm_router_instance.get_model_ids() model_ids = litellm_router_instance.get_model_ids()
cooldown_models = ( cooldown_models = (
await litellm_router_instance.cooldown_cache.async_get_active_cooldowns( await litellm_router_instance.cooldown_cache.async_get_active_cooldowns(
model_ids=model_ids model_ids=model_ids, parent_otel_span=parent_otel_span
) )
) )
@@ -248,7 +254,9 @@ async def _async_get_cooldown_deployments_with_debug_info(
return cooldown_models return cooldown_models
def _get_cooldown_deployments(litellm_router_instance: LitellmRouter) -> List[str]: def _get_cooldown_deployments(
litellm_router_instance: LitellmRouter, parent_otel_span: Optional[Span]
) -> List[str]:
""" """
Get the list of models being cooled down for this minute Get the list of models being cooled down for this minute
""" """
@@ -258,8 +266,9 @@ def _get_cooldown_deployments(litellm_router_instance: LitellmRouter) -> List[str]:
# Return cooldown models # Return cooldown models
# ---------------------- # ----------------------
model_ids = litellm_router_instance.get_model_ids() model_ids = litellm_router_instance.get_model_ids()
cooldown_models = litellm_router_instance.cooldown_cache.get_active_cooldowns( cooldown_models = litellm_router_instance.cooldown_cache.get_active_cooldowns(
model_ids=model_ids model_ids=model_ids, parent_otel_span=parent_otel_span
) )
cached_value_deployment_ids = [] cached_value_deployment_ids = []

View file

@@ -14,6 +14,7 @@ class ServiceTypes(str, enum.Enum):
DB = "postgres" DB = "postgres"
BATCH_WRITE_TO_DB = "batch_write_to_db" BATCH_WRITE_TO_DB = "batch_write_to_db"
LITELLM = "self" LITELLM = "self"
ROUTER = "router"
class ServiceLoggerPayload(BaseModel): class ServiceLoggerPayload(BaseModel):
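
ServiceTypes.ROUTER is a plain string member, so existing consumers of service metrics can filter router timings by name without any other change. A trivial sketch; the import path is an assumption.

from litellm.types.services import ServiceTypes

assert ServiceTypes.ROUTER.value == "router"


def is_router_event(service: ServiceTypes) -> bool:
    # True for the deployment-selection timings emitted by the router changes above.
    return service == ServiceTypes.ROUTER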

View file

@@ -83,7 +83,9 @@ def test_dual_cache_batch_get_cache():
in_memory_cache.set_cache(key="test_value", value="hello world") in_memory_cache.set_cache(key="test_value", value="hello world")
result = dual_cache.batch_get_cache(keys=["test_value", "test_value_2"]) result = dual_cache.batch_get_cache(
keys=["test_value", "test_value_2"], parent_otel_span=None
)
assert result[0] == "hello world" assert result[0] == "hello world"
assert result[1] == None assert result[1] == None

View file

@@ -2447,11 +2447,11 @@ async def test_aaarouter_dynamic_cooldown_message_retry_time(sync_mode):
if sync_mode: if sync_mode:
cooldown_deployments = _get_cooldown_deployments( cooldown_deployments = _get_cooldown_deployments(
litellm_router_instance=router litellm_router_instance=router, parent_otel_span=None
) )
else: else:
cooldown_deployments = await _async_get_cooldown_deployments( cooldown_deployments = await _async_get_cooldown_deployments(
litellm_router_instance=router litellm_router_instance=router, parent_otel_span=None
) )
print( print(
"Cooldown deployments - {}\n{}".format( "Cooldown deployments - {}\n{}".format(

View file

@@ -242,12 +242,12 @@ async def test_single_deployment_no_cooldowns_test_prod_mock_completion_calls():
pass pass
cooldown_list = await _async_get_cooldown_deployments( cooldown_list = await _async_get_cooldown_deployments(
litellm_router_instance=router litellm_router_instance=router, parent_otel_span=None
) )
assert len(cooldown_list) == 0 assert len(cooldown_list) == 0
healthy_deployments, _ = await router._async_get_healthy_deployments( healthy_deployments, _ = await router._async_get_healthy_deployments(
model="gpt-3.5-turbo" model="gpt-3.5-turbo", parent_otel_span=None
) )
print("healthy_deployments: ", healthy_deployments) print("healthy_deployments: ", healthy_deployments)
@@ -351,7 +351,7 @@ async def test_high_traffic_cooldowns_all_healthy_deployments():
print("model_stats: ", model_stats) print("model_stats: ", model_stats)
cooldown_list = await _async_get_cooldown_deployments( cooldown_list = await _async_get_cooldown_deployments(
litellm_router_instance=router litellm_router_instance=router, parent_otel_span=None
) )
assert len(cooldown_list) == 0 assert len(cooldown_list) == 0
@@ -449,7 +449,7 @@ async def test_high_traffic_cooldowns_one_bad_deployment():
print("model_stats: ", model_stats) print("model_stats: ", model_stats)
cooldown_list = await _async_get_cooldown_deployments( cooldown_list = await _async_get_cooldown_deployments(
litellm_router_instance=router litellm_router_instance=router, parent_otel_span=None
) )
assert len(cooldown_list) == 1 assert len(cooldown_list) == 1
@@ -550,7 +550,7 @@ async def test_high_traffic_cooldowns_one_rate_limited_deployment():
print("model_stats: ", model_stats) print("model_stats: ", model_stats)
cooldown_list = await _async_get_cooldown_deployments( cooldown_list = await _async_get_cooldown_deployments(
litellm_router_instance=router litellm_router_instance=router, parent_otel_span=None
) )
assert len(cooldown_list) == 1 assert len(cooldown_list) == 1

View file

@@ -440,12 +440,12 @@ def test_update_usage(model_list):
) )
deployment_id = deployment["model_info"]["id"] deployment_id = deployment["model_info"]["id"]
request_count = router._update_usage( request_count = router._update_usage(
deployment_id=deployment_id, deployment_id=deployment_id, parent_otel_span=None
) )
assert request_count == 1 assert request_count == 1
request_count = router._update_usage( request_count = router._update_usage(
deployment_id=deployment_id, deployment_id=deployment_id, parent_otel_span=None
) )
assert request_count == 2 assert request_count == 2
@@ -482,7 +482,9 @@ def test_should_raise_content_policy_error(model_list, finish_reason, expected_e
def test_get_healthy_deployments(model_list): def test_get_healthy_deployments(model_list):
"""Test if the '_get_healthy_deployments' function is working correctly""" """Test if the '_get_healthy_deployments' function is working correctly"""
router = Router(model_list=model_list) router = Router(model_list=model_list)
deployments = router._get_healthy_deployments(model="gpt-3.5-turbo") deployments = router._get_healthy_deployments(
model="gpt-3.5-turbo", parent_otel_span=None
)
assert len(deployments) > 0 assert len(deployments) > 0
@@ -756,6 +758,7 @@ def test_track_deployment_metrics(model_list):
model="gpt-3.5-turbo", model="gpt-3.5-turbo",
usage={"total_tokens": 100}, usage={"total_tokens": 100},
), ),
parent_otel_span=None,
) )
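
The test updates are mechanical: every helper that now requires a parent_otel_span receives None, keeping tracing optional in unit tests. A hypothetical test in the same shape; the model_list fixture and the cooldown_handlers import path are assumptions based on the files shown above.

import pytest

from litellm import Router
from litellm.router_utils.cooldown_handlers import _async_get_cooldown_deployments


@pytest.mark.asyncio
async def test_fresh_router_has_no_cooldowns(model_list):
    router = Router(model_list=model_list)
    cooldown_list = await _async_get_cooldown_deployments(
        litellm_router_instance=router, parent_otel_span=None
    )
    assert cooldown_list == []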