From 4f8a3fd4cfc20cf43b38379928b41c2691c85d36 Mon Sep 17 00:00:00 2001 From: Krish Dholakia Date: Mon, 28 Oct 2024 21:52:12 -0700 Subject: [PATCH] redis otel tracing + async support for latency routing (#6452) * docs(exception_mapping.md): add missing exception types Fixes https://github.com/Aider-AI/aider/issues/2120#issuecomment-2438971183 * fix(main.py): register custom model pricing with specific key Ensure custom model pricing is registered to the specific model+provider key combination * test: make testing more robust for custom pricing * fix(redis_cache.py): instrument otel logging for sync redis calls ensures complete coverage for all redis cache calls * refactor: pass parent_otel_span for redis caching calls in router allows for more observability into what calls are causing latency issues * test: update tests with new params * refactor: ensure e2e otel tracing for router * refactor(router.py): add more otel tracing acrosss router catch all latency issues for router requests * fix: fix linting error * fix(router.py): fix linting error * fix: fix test * test: fix tests * fix(dual_cache.py): pass ttl to redis cache * fix: fix param --- litellm/_service_logger.py | 6 +- litellm/caching/base_cache.py | 9 +- litellm/caching/disk_cache.py | 9 +- litellm/caching/dual_cache.py | 67 ++++- litellm/caching/redis_cache.py | 36 ++- .../SlackAlerting/slack_alerting.py | 5 +- litellm/integrations/custom_logger.py | 13 +- litellm/integrations/opentelemetry.py | 34 ++- litellm/proxy/_experimental/out/404.html | 1 - .../proxy/_experimental/out/model_hub.html | 1 - .../proxy/_experimental/out/onboarding.html | 1 - litellm/proxy/_new_secret_config.yaml | 2 +- litellm/proxy/auth/user_api_key_auth.py | 1 + litellm/proxy/hooks/max_budget_limiter.py | 4 +- litellm/proxy/utils.py | 9 +- litellm/router.py | 283 ++++++++++++++---- litellm/router_strategy/lowest_latency.py | 112 +++++-- litellm/router_strategy/lowest_tpm_rpm_v2.py | 31 +- litellm/router_utils/cooldown_cache.py | 35 ++- litellm/router_utils/cooldown_handlers.py | 19 +- litellm/types/services.py | 1 + tests/local_testing/test_caching.py | 4 +- tests/local_testing/test_router.py | 4 +- tests/local_testing/test_router_cooldowns.py | 10 +- .../test_router_helper_utils.py | 9 +- 25 files changed, 559 insertions(+), 147 deletions(-) delete mode 100644 litellm/proxy/_experimental/out/404.html delete mode 100644 litellm/proxy/_experimental/out/model_hub.html delete mode 100644 litellm/proxy/_experimental/out/onboarding.html diff --git a/litellm/_service_logger.py b/litellm/_service_logger.py index 0e738561b..4db645e66 100644 --- a/litellm/_service_logger.py +++ b/litellm/_service_logger.py @@ -44,8 +44,7 @@ class ServiceLogging(CustomLogger): """ Handles both sync and async monitoring by checking for existing event loop. 
""" - # if service == ServiceTypes.REDIS: - # print(f"SYNC service: {service}, call_type: {call_type}") + if self.mock_testing: self.mock_testing_sync_success_hook += 1 @@ -112,8 +111,7 @@ class ServiceLogging(CustomLogger): """ - For counting if the redis, postgres call is successful """ - # if service == ServiceTypes.REDIS: - # print(f"service: {service}, call_type: {call_type}") + if self.mock_testing: self.mock_testing_async_success_hook += 1 diff --git a/litellm/caching/base_cache.py b/litellm/caching/base_cache.py index 016ad70b9..0699832ab 100644 --- a/litellm/caching/base_cache.py +++ b/litellm/caching/base_cache.py @@ -8,7 +8,14 @@ Has 4 methods: - async_get_cache """ -from typing import Optional +from typing import TYPE_CHECKING, Any, Optional + +if TYPE_CHECKING: + from opentelemetry.trace import Span as _Span + + Span = _Span +else: + Span = Any class BaseCache: diff --git a/litellm/caching/disk_cache.py b/litellm/caching/disk_cache.py index 830d21d9c..2c086ed50 100644 --- a/litellm/caching/disk_cache.py +++ b/litellm/caching/disk_cache.py @@ -1,10 +1,17 @@ import json -from typing import Optional +from typing import TYPE_CHECKING, Any, Optional from litellm._logging import print_verbose from .base_cache import BaseCache +if TYPE_CHECKING: + from opentelemetry.trace import Span as _Span + + Span = _Span +else: + Span = Any + class DiskCache(BaseCache): def __init__(self, disk_cache_dir: Optional[str] = None): diff --git a/litellm/caching/dual_cache.py b/litellm/caching/dual_cache.py index 028aa59bf..35659b865 100644 --- a/litellm/caching/dual_cache.py +++ b/litellm/caching/dual_cache.py @@ -9,7 +9,7 @@ Has 4 primary methods: """ import traceback -from typing import List, Optional +from typing import TYPE_CHECKING, Any, List, Optional import litellm from litellm._logging import print_verbose, verbose_logger @@ -18,6 +18,13 @@ from .base_cache import BaseCache from .in_memory_cache import InMemoryCache from .redis_cache import RedisCache +if TYPE_CHECKING: + from opentelemetry.trace import Span as _Span + + Span = _Span +else: + Span = Any + class DualCache(BaseCache): """ @@ -90,7 +97,13 @@ class DualCache(BaseCache): verbose_logger.error(f"LiteLLM Cache: Excepton async add_cache: {str(e)}") raise e - def get_cache(self, key, local_only: bool = False, **kwargs): + def get_cache( + self, + key, + parent_otel_span: Optional[Span] = None, + local_only: bool = False, + **kwargs, + ): # Try to fetch from in-memory cache first try: result = None @@ -102,7 +115,9 @@ class DualCache(BaseCache): if result is None and self.redis_cache is not None and local_only is False: # If not found in in-memory cache, try fetching from Redis - redis_result = self.redis_cache.get_cache(key, **kwargs) + redis_result = self.redis_cache.get_cache( + key, parent_otel_span=parent_otel_span + ) if redis_result is not None: # Update in-memory cache with the value from Redis @@ -115,7 +130,13 @@ class DualCache(BaseCache): except Exception: verbose_logger.error(traceback.format_exc()) - def batch_get_cache(self, keys: list, local_only: bool = False, **kwargs): + def batch_get_cache( + self, + keys: list, + parent_otel_span: Optional[Span], + local_only: bool = False, + **kwargs, + ): try: result = [None for _ in range(len(keys))] if self.in_memory_cache is not None: @@ -133,7 +154,9 @@ class DualCache(BaseCache): key for key, value in zip(keys, result) if value is None ] # If not found in in-memory cache, try fetching from Redis - redis_result = self.redis_cache.batch_get_cache(sublist_keys, **kwargs) + 
redis_result = self.redis_cache.batch_get_cache( + sublist_keys, parent_otel_span=parent_otel_span + ) if redis_result is not None: # Update in-memory cache with the value from Redis for key in redis_result: @@ -147,7 +170,13 @@ class DualCache(BaseCache): except Exception: verbose_logger.error(traceback.format_exc()) - async def async_get_cache(self, key, local_only: bool = False, **kwargs): + async def async_get_cache( + self, + key, + parent_otel_span: Optional[Span] = None, + local_only: bool = False, + **kwargs, + ): # Try to fetch from in-memory cache first try: print_verbose( @@ -165,7 +194,9 @@ class DualCache(BaseCache): if result is None and self.redis_cache is not None and local_only is False: # If not found in in-memory cache, try fetching from Redis - redis_result = await self.redis_cache.async_get_cache(key, **kwargs) + redis_result = await self.redis_cache.async_get_cache( + key, parent_otel_span=parent_otel_span + ) if redis_result is not None: # Update in-memory cache with the value from Redis @@ -181,7 +212,11 @@ class DualCache(BaseCache): verbose_logger.error(traceback.format_exc()) async def async_batch_get_cache( - self, keys: list, local_only: bool = False, **kwargs + self, + keys: list, + parent_otel_span: Optional[Span] = None, + local_only: bool = False, + **kwargs, ): try: result = [None for _ in range(len(keys))] @@ -202,7 +237,7 @@ class DualCache(BaseCache): ] # If not found in in-memory cache, try fetching from Redis redis_result = await self.redis_cache.async_batch_get_cache( - sublist_keys, **kwargs + sublist_keys, parent_otel_span=parent_otel_span ) if redis_result is not None: @@ -260,7 +295,12 @@ class DualCache(BaseCache): ) async def async_increment_cache( - self, key, value: float, local_only: bool = False, **kwargs + self, + key, + value: float, + parent_otel_span: Optional[Span], + local_only: bool = False, + **kwargs, ) -> float: """ Key - the key in cache @@ -277,7 +317,12 @@ class DualCache(BaseCache): ) if self.redis_cache is not None and local_only is False: - result = await self.redis_cache.async_increment(key, value, **kwargs) + result = await self.redis_cache.async_increment( + key, + value, + parent_otel_span=parent_otel_span, + ttl=kwargs.get("ttl", None), + ) return result except Exception as e: diff --git a/litellm/caching/redis_cache.py b/litellm/caching/redis_cache.py index 0160f2f0c..e6c408cc8 100644 --- a/litellm/caching/redis_cache.py +++ b/litellm/caching/redis_cache.py @@ -13,6 +13,7 @@ import asyncio import inspect import json import time +import traceback from datetime import timedelta from typing import TYPE_CHECKING, Any, List, Optional, Tuple @@ -25,14 +26,17 @@ from litellm.types.utils import all_litellm_params from .base_cache import BaseCache if TYPE_CHECKING: + from opentelemetry.trace import Span as _Span from redis.asyncio import Redis from redis.asyncio.client import Pipeline pipeline = Pipeline async_redis_client = Redis + Span = _Span else: pipeline = Any async_redis_client = Any + Span = Any class RedisCache(BaseCache): @@ -524,7 +528,11 @@ class RedisCache(BaseCache): await self.flush_cache_buffer() # logging done in here async def async_increment( - self, key, value: float, ttl: Optional[int] = None, **kwargs + self, + key, + value: float, + ttl: Optional[int] = None, + parent_otel_span: Optional[Span] = None, ) -> float: from redis.asyncio import Redis @@ -552,7 +560,7 @@ class RedisCache(BaseCache): call_type="async_increment", start_time=start_time, end_time=end_time, - 
parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs), + parent_otel_span=parent_otel_span, ) ) return result @@ -568,7 +576,7 @@ class RedisCache(BaseCache): call_type="async_increment", start_time=start_time, end_time=end_time, - parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs), + parent_otel_span=parent_otel_span, ) ) verbose_logger.error( @@ -601,7 +609,7 @@ class RedisCache(BaseCache): cached_response = ast.literal_eval(cached_response) return cached_response - def get_cache(self, key, **kwargs): + def get_cache(self, key, parent_otel_span: Optional[Span] = None, **kwargs): try: key = self.check_and_fix_namespace(key=key) print_verbose(f"Get Redis Cache: key: {key}") @@ -615,6 +623,7 @@ class RedisCache(BaseCache): call_type="get_cache", start_time=start_time, end_time=end_time, + parent_otel_span=parent_otel_span, ) print_verbose( f"Got Redis Cache: key: {key}, cached_response {cached_response}" @@ -626,11 +635,12 @@ class RedisCache(BaseCache): "litellm.caching.caching: get() - Got exception from REDIS: ", e ) - def batch_get_cache(self, key_list) -> dict: + def batch_get_cache(self, key_list, parent_otel_span: Optional[Span]) -> dict: """ Use Redis for bulk read operations """ key_value_dict = {} + try: _keys = [] for cache_key in key_list: @@ -646,6 +656,7 @@ class RedisCache(BaseCache): call_type="batch_get_cache", start_time=start_time, end_time=end_time, + parent_otel_span=parent_otel_span, ) # Associate the results back with their keys. @@ -662,7 +673,9 @@ class RedisCache(BaseCache): print_verbose(f"Error occurred in pipeline read - {str(e)}") return key_value_dict - async def async_get_cache(self, key, **kwargs): + async def async_get_cache( + self, key, parent_otel_span: Optional[Span] = None, **kwargs + ): from redis.asyncio import Redis _redis_client: Redis = self.init_async_client() # type: ignore @@ -686,7 +699,7 @@ class RedisCache(BaseCache): call_type="async_get_cache", start_time=start_time, end_time=end_time, - parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs), + parent_otel_span=parent_otel_span, event_metadata={"key": key}, ) ) @@ -703,7 +716,7 @@ class RedisCache(BaseCache): call_type="async_get_cache", start_time=start_time, end_time=end_time, - parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs), + parent_otel_span=parent_otel_span, event_metadata={"key": key}, ) ) @@ -712,10 +725,13 @@ class RedisCache(BaseCache): f"litellm.caching.caching: async get() - Got exception from REDIS: {str(e)}" ) - async def async_batch_get_cache(self, key_list) -> dict: + async def async_batch_get_cache( + self, key_list: List[str], parent_otel_span: Optional[Span] = None + ) -> dict: """ Use Redis for bulk read operations """ + _redis_client = await self.init_async_client() key_value_dict = {} start_time = time.time() @@ -737,6 +753,7 @@ class RedisCache(BaseCache): call_type="async_batch_get_cache", start_time=start_time, end_time=end_time, + parent_otel_span=parent_otel_span, ) ) @@ -764,6 +781,7 @@ class RedisCache(BaseCache): call_type="async_batch_get_cache", start_time=start_time, end_time=end_time, + parent_otel_span=parent_otel_span, ) ) print_verbose(f"Error occurred in pipeline read - {str(e)}") diff --git a/litellm/integrations/SlackAlerting/slack_alerting.py b/litellm/integrations/SlackAlerting/slack_alerting.py index 92223aff3..dbe9e4161 100644 --- a/litellm/integrations/SlackAlerting/slack_alerting.py +++ b/litellm/integrations/SlackAlerting/slack_alerting.py @@ -268,6 +268,7 @@ class SlackAlerting(CustomBatchLogger): 
SlackAlertingCacheKeys.failed_requests_key.value, ), value=1, + parent_otel_span=None, # no attached request, this is a background operation ) return_val += 1 @@ -279,6 +280,7 @@ class SlackAlerting(CustomBatchLogger): deployment_metrics.id, SlackAlertingCacheKeys.latency_key.value ), value=deployment_metrics.latency_per_output_token, + parent_otel_span=None, # no attached request, this is a background operation ) return_val += 1 @@ -1518,7 +1520,8 @@ Model Info: report_sent_bool = False report_sent = await self.internal_usage_cache.async_get_cache( - key=SlackAlertingCacheKeys.report_sent_key.value + key=SlackAlertingCacheKeys.report_sent_key.value, + parent_otel_span=None, ) # None | float current_time = time.time() diff --git a/litellm/integrations/custom_logger.py b/litellm/integrations/custom_logger.py index bdcb6a52f..d62bd3e4d 100644 --- a/litellm/integrations/custom_logger.py +++ b/litellm/integrations/custom_logger.py @@ -3,7 +3,7 @@ import os import traceback from datetime import datetime as datetimeObj -from typing import Any, Literal, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Literal, Optional, Tuple, Union import dotenv from pydantic import BaseModel @@ -21,6 +21,13 @@ from litellm.types.utils import ( StandardLoggingPayload, ) +if TYPE_CHECKING: + from opentelemetry.trace import Span as _Span + + Span = _Span +else: + Span = Any + class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callback#callback-class # Class variables or attributes @@ -62,7 +69,9 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac Allows usage-based-routing-v2 to run pre-call rpm checks within the picked deployment's semaphore (concurrency-safe tpm/rpm checks). """ - async def async_pre_call_check(self, deployment: dict) -> Optional[dict]: + async def async_pre_call_check( + self, deployment: dict, parent_otel_span: Optional[Span] + ) -> Optional[dict]: pass def pre_call_check(self, deployment: dict) -> Optional[dict]: diff --git a/litellm/integrations/opentelemetry.py b/litellm/integrations/opentelemetry.py index f1b7abbbb..a706cebbf 100644 --- a/litellm/integrations/opentelemetry.py +++ b/litellm/integrations/opentelemetry.py @@ -8,7 +8,12 @@ import litellm from litellm._logging import verbose_logger from litellm.integrations.custom_logger import CustomLogger from litellm.types.services import ServiceLoggerPayload -from litellm.types.utils import StandardLoggingPayload +from litellm.types.utils import ( + EmbeddingResponse, + ImageResponse, + ModelResponse, + StandardLoggingPayload, +) if TYPE_CHECKING: from opentelemetry.trace import Span as _Span @@ -136,12 +141,12 @@ class OpenTelemetry(CustomLogger): _end_time_ns = 0 if isinstance(start_time, float): - _start_time_ns = int(int(start_time) * 1e9) + _start_time_ns = int(start_time * 1e9) else: _start_time_ns = self._to_ns(start_time) if isinstance(end_time, float): - _end_time_ns = int(int(end_time) * 1e9) + _end_time_ns = int(end_time * 1e9) else: _end_time_ns = self._to_ns(end_time) @@ -276,6 +281,21 @@ class OpenTelemetry(CustomLogger): # End Parent OTEL Sspan parent_otel_span.end(end_time=self._to_ns(datetime.now())) + async def async_post_call_success_hook( + self, + data: dict, + user_api_key_dict: UserAPIKeyAuth, + response: Union[Any, ModelResponse, EmbeddingResponse, ImageResponse], + ): + from opentelemetry import trace + from opentelemetry.trace import Status, StatusCode + + parent_otel_span = user_api_key_dict.parent_otel_span + if parent_otel_span is not None: + 
parent_otel_span.set_status(Status(StatusCode.OK)) + # End Parent OTEL Sspan + parent_otel_span.end(end_time=self._to_ns(datetime.now())) + def _handle_sucess(self, kwargs, response_obj, start_time, end_time): from opentelemetry import trace from opentelemetry.trace import Status, StatusCode @@ -314,8 +334,8 @@ class OpenTelemetry(CustomLogger): span.end(end_time=self._to_ns(end_time)) - if parent_otel_span is not None: - parent_otel_span.end(end_time=self._to_ns(datetime.now())) + # if parent_otel_span is not None: + # parent_otel_span.end(end_time=self._to_ns(datetime.now())) def _handle_failure(self, kwargs, response_obj, start_time, end_time): from opentelemetry.trace import Status, StatusCode @@ -808,12 +828,12 @@ class OpenTelemetry(CustomLogger): end_time = logging_payload.end_time if isinstance(start_time, float): - _start_time_ns = int(int(start_time) * 1e9) + _start_time_ns = int(start_time * 1e9) else: _start_time_ns = self._to_ns(start_time) if isinstance(end_time, float): - _end_time_ns = int(int(end_time) * 1e9) + _end_time_ns = int(end_time * 1e9) else: _end_time_ns = self._to_ns(end_time) diff --git a/litellm/proxy/_experimental/out/404.html b/litellm/proxy/_experimental/out/404.html deleted file mode 100644 index 40924db8d..000000000 --- a/litellm/proxy/_experimental/out/404.html +++ /dev/null @@ -1 +0,0 @@ -404: This page could not be found.LiteLLM Dashboard
\ No newline at end of file diff --git a/litellm/proxy/_experimental/out/model_hub.html b/litellm/proxy/_experimental/out/model_hub.html deleted file mode 100644 index bbf4caf53..000000000 --- a/litellm/proxy/_experimental/out/model_hub.html +++ /dev/null @@ -1 +0,0 @@ -LiteLLM Dashboard \ No newline at end of file diff --git a/litellm/proxy/_experimental/out/onboarding.html b/litellm/proxy/_experimental/out/onboarding.html deleted file mode 100644 index 82b9b495f..000000000 --- a/litellm/proxy/_experimental/out/onboarding.html +++ /dev/null @@ -1 +0,0 @@ -LiteLLM Dashboard \ No newline at end of file diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index ad045adb5..5de5413ed 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -25,4 +25,4 @@ router_settings: ttl: 300 redis_host: os.environ/REDIS_HOST redis_port: os.environ/REDIS_PORT - redis_password: os.environ/REDIS_PASSWORD + redis_password: os.environ/REDIS_PASSWORD \ No newline at end of file diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py index bbdddeee9..a8cc9193e 100644 --- a/litellm/proxy/auth/user_api_key_auth.py +++ b/litellm/proxy/auth/user_api_key_auth.py @@ -260,6 +260,7 @@ async def user_api_key_auth( # noqa: PLR0915 headers=request.headers ), ) + ### USER-DEFINED AUTH FUNCTION ### if user_custom_auth is not None: response = await user_custom_auth(request=request, api_key=api_key) # type: ignore diff --git a/litellm/proxy/hooks/max_budget_limiter.py b/litellm/proxy/hooks/max_budget_limiter.py index 8fa7a33a0..c1c5b4b80 100644 --- a/litellm/proxy/hooks/max_budget_limiter.py +++ b/litellm/proxy/hooks/max_budget_limiter.py @@ -28,7 +28,9 @@ class _PROXY_MaxBudgetLimiter(CustomLogger): try: self.print_verbose("Inside Max Budget Limiter Pre-Call Hook") cache_key = f"{user_api_key_dict.user_id}_user_api_key_user_id" - user_row = cache.get_cache(cache_key) + user_row = await cache.async_get_cache( + cache_key, parent_otel_span=user_api_key_dict.parent_otel_span + ) if user_row is None: # value not yet cached return max_budget = user_row["max_budget"] diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 6089bf0c3..4a10a0179 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -235,7 +235,7 @@ class InternalUsageCache: return await self.dual_cache.async_get_cache( key=key, local_only=local_only, - litellm_parent_otel_span=litellm_parent_otel_span, + parent_otel_span=litellm_parent_otel_span, **kwargs, ) @@ -281,7 +281,7 @@ class InternalUsageCache: key=key, value=value, local_only=local_only, - litellm_parent_otel_span=litellm_parent_otel_span, + parent_otel_span=litellm_parent_otel_span, **kwargs, ) @@ -367,7 +367,10 @@ class ProxyLogging: llm_router=llm_router ) # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made - if "daily_reports" in self.slack_alerting_instance.alert_types: + if ( + self.slack_alerting_instance is not None + and "daily_reports" in self.slack_alerting_instance.alert_types + ): asyncio.create_task( self.slack_alerting_instance._run_scheduled_daily_report( llm_router=llm_router diff --git a/litellm/router.py b/litellm/router.py index 0cad565b0..5ccdbcf4a 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -25,6 +25,7 @@ import uuid from collections import defaultdict from datetime import datetime from typing import ( + TYPE_CHECKING, Any, Callable, Dict, @@ -50,6 +51,7 
@@ from litellm._logging import verbose_router_logger from litellm.assistants.main import AssistantDeleted from litellm.caching.caching import DualCache, InMemoryCache, RedisCache from litellm.integrations.custom_logger import CustomLogger +from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging from litellm.llms.AzureOpenAI.azure import get_azure_ad_token_from_oidc from litellm.router_strategy.least_busy import LeastBusyLoggingHandler @@ -124,6 +126,7 @@ from litellm.types.router import ( updateDeployment, updateLiteLLMParams, ) +from litellm.types.services import ServiceLoggerPayload, ServiceTypes from litellm.types.utils import OPENAI_RESPONSE_HEADERS from litellm.types.utils import ModelInfo as ModelMapInfo from litellm.utils import ( @@ -140,6 +143,13 @@ from litellm.utils import ( from .router_utils.pattern_match_deployments import PatternMatchRouter +if TYPE_CHECKING: + from opentelemetry.trace import Span as _Span + + Span = _Span +else: + Span = Any + class RoutingArgs(enum.Enum): ttl = 60 # 1min (RPM/TPM expire key) @@ -293,6 +303,8 @@ class Router: ``` """ + from litellm._service_logger import ServiceLogging + if semaphore: self.semaphore = semaphore self.set_verbose = set_verbose @@ -494,7 +506,7 @@ class Router: f"Routing context window fallbacks: {self.context_window_fallbacks}\n\n" f"Router Redis Caching={self.cache.redis_cache}\n" ) - + self.service_logger_obj = ServiceLogging() self.routing_strategy_args = routing_strategy_args self.retry_policy: Optional[RetryPolicy] = None if retry_policy is not None: @@ -762,10 +774,23 @@ class Router: request_priority = kwargs.get("priority") or self.default_priority + start_time = time.time() if request_priority is not None and isinstance(request_priority, int): response = await self.schedule_acompletion(**kwargs) else: response = await self.async_function_with_fallbacks(**kwargs) + end_time = time.time() + _duration = end_time - start_time + asyncio.create_task( + self.service_logger_obj.async_service_success_hook( + service=ServiceTypes.ROUTER, + duration=_duration, + call_type="acompletion", + start_time=start_time, + end_time=end_time, + parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs), + ) + ) return response except Exception as e: @@ -793,15 +818,32 @@ class Router: verbose_router_logger.debug( f"Inside _acompletion()- model: {model}; kwargs: {kwargs}" ) + parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs) + start_time = time.time() deployment = await self.async_get_available_deployment( model=model, messages=messages, specific_deployment=kwargs.pop("specific_deployment", None), request_kwargs=kwargs, ) + end_time = time.time() + _duration = end_time - start_time + asyncio.create_task( + self.service_logger_obj.async_service_success_hook( + service=ServiceTypes.ROUTER, + duration=_duration, + call_type="async_get_available_deployment", + start_time=start_time, + end_time=end_time, + parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs), + ) + ) # debug how often this deployment picked - self._track_deployment_metrics(deployment=deployment) + + self._track_deployment_metrics( + deployment=deployment, parent_otel_span=parent_otel_span + ) self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs) data = deployment["litellm_params"].copy() @@ -846,12 +888,16 @@ class Router: - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe) """ await 
self.async_routing_strategy_pre_call_checks( - deployment=deployment, logging_obj=logging_obj + deployment=deployment, + logging_obj=logging_obj, + parent_otel_span=parent_otel_span, ) response = await _response else: await self.async_routing_strategy_pre_call_checks( - deployment=deployment, logging_obj=logging_obj + deployment=deployment, + logging_obj=logging_obj, + parent_otel_span=parent_otel_span, ) response = await _response @@ -872,7 +918,11 @@ class Router: f"litellm.acompletion(model={model_name})\033[32m 200 OK\033[0m" ) # debug how often this deployment picked - self._track_deployment_metrics(deployment=deployment, response=response) + self._track_deployment_metrics( + deployment=deployment, + response=response, + parent_otel_span=parent_otel_span, + ) return response except Exception as e: @@ -1212,6 +1262,7 @@ class Router: stream=False, **kwargs, ): + parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs) ### FLOW ITEM ### _request_id = str(uuid.uuid4()) item = FlowItem( @@ -1232,7 +1283,7 @@ class Router: while curr_time < end_time: _healthy_deployments, _ = await self._async_get_healthy_deployments( - model=model + model=model, parent_otel_span=parent_otel_span ) make_request = await self.scheduler.poll( ## POLL QUEUE ## - returns 'True' if there's healthy deployments OR if request is at top of queue id=item.request_id, @@ -1353,6 +1404,7 @@ class Router: verbose_router_logger.debug( f"Inside _image_generation()- model: {model}; kwargs: {kwargs}" ) + parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs) deployment = await self.async_get_available_deployment( model=model, messages=[{"role": "user", "content": "prompt"}], @@ -1395,11 +1447,13 @@ class Router: - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe) """ await self.async_routing_strategy_pre_call_checks( - deployment=deployment + deployment=deployment, parent_otel_span=parent_otel_span ) response = await response else: - await self.async_routing_strategy_pre_call_checks(deployment=deployment) + await self.async_routing_strategy_pre_call_checks( + deployment=deployment, parent_otel_span=parent_otel_span + ) response = await response self.success_calls[model_name] += 1 @@ -1465,6 +1519,7 @@ class Router: verbose_router_logger.debug( f"Inside _atranscription()- model: {model}; kwargs: {kwargs}" ) + parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs) deployment = await self.async_get_available_deployment( model=model, messages=[{"role": "user", "content": "prompt"}], @@ -1505,11 +1560,13 @@ class Router: - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe) """ await self.async_routing_strategy_pre_call_checks( - deployment=deployment + deployment=deployment, parent_otel_span=parent_otel_span ) response = await response else: - await self.async_routing_strategy_pre_call_checks(deployment=deployment) + await self.async_routing_strategy_pre_call_checks( + deployment=deployment, parent_otel_span=parent_otel_span + ) response = await response self.success_calls[model_name] += 1 @@ -1861,6 +1918,7 @@ class Router: verbose_router_logger.debug( f"Inside _atext_completion()- model: {model}; kwargs: {kwargs}" ) + parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs) deployment = await self.async_get_available_deployment( model=model, messages=[{"role": "user", "content": prompt}], @@ -1903,11 +1961,13 @@ class Router: - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe) """ await 
self.async_routing_strategy_pre_call_checks( - deployment=deployment + deployment=deployment, parent_otel_span=parent_otel_span ) response = await response else: - await self.async_routing_strategy_pre_call_checks(deployment=deployment) + await self.async_routing_strategy_pre_call_checks( + deployment=deployment, parent_otel_span=parent_otel_span + ) response = await response self.success_calls[model_name] += 1 @@ -1958,6 +2018,7 @@ class Router: verbose_router_logger.debug( f"Inside _aadapter_completion()- model: {model}; kwargs: {kwargs}" ) + parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs) deployment = await self.async_get_available_deployment( model=model, messages=[{"role": "user", "content": "default text"}], @@ -2000,11 +2061,13 @@ class Router: - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe) """ await self.async_routing_strategy_pre_call_checks( - deployment=deployment + deployment=deployment, parent_otel_span=parent_otel_span ) response = await response # type: ignore else: - await self.async_routing_strategy_pre_call_checks(deployment=deployment) + await self.async_routing_strategy_pre_call_checks( + deployment=deployment, parent_otel_span=parent_otel_span + ) response = await response # type: ignore self.success_calls[model_name] += 1 @@ -2128,6 +2191,7 @@ class Router: verbose_router_logger.debug( f"Inside _aembedding()- model: {model}; kwargs: {kwargs}" ) + parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs) deployment = await self.async_get_available_deployment( model=model, input=input, @@ -2168,11 +2232,13 @@ class Router: - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe) """ await self.async_routing_strategy_pre_call_checks( - deployment=deployment + deployment=deployment, parent_otel_span=parent_otel_span ) response = await response else: - await self.async_routing_strategy_pre_call_checks(deployment=deployment) + await self.async_routing_strategy_pre_call_checks( + deployment=deployment, parent_otel_span=parent_otel_span + ) response = await response self.success_calls[model_name] += 1 @@ -2223,6 +2289,7 @@ class Router: verbose_router_logger.debug( f"Inside _atext_completion()- model: {model}; kwargs: {kwargs}" ) + parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs) deployment = await self.async_get_available_deployment( model=model, messages=[{"role": "user", "content": "files-api-fake-text"}], @@ -2273,11 +2340,13 @@ class Router: - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe) """ await self.async_routing_strategy_pre_call_checks( - deployment=deployment + deployment=deployment, parent_otel_span=parent_otel_span ) response = await response # type: ignore else: - await self.async_routing_strategy_pre_call_checks(deployment=deployment) + await self.async_routing_strategy_pre_call_checks( + deployment=deployment, parent_otel_span=parent_otel_span + ) response = await response # type: ignore self.success_calls[model_name] += 1 @@ -2327,6 +2396,7 @@ class Router: verbose_router_logger.debug( f"Inside _acreate_batch()- model: {model}; kwargs: {kwargs}" ) + parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs) deployment = await self.async_get_available_deployment( model=model, messages=[{"role": "user", "content": "files-api-fake-text"}], @@ -2389,11 +2459,13 @@ class Router: - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe) """ await 
self.async_routing_strategy_pre_call_checks( - deployment=deployment + deployment=deployment, parent_otel_span=parent_otel_span ) response = await response # type: ignore else: - await self.async_routing_strategy_pre_call_checks(deployment=deployment) + await self.async_routing_strategy_pre_call_checks( + deployment=deployment, parent_otel_span=parent_otel_span + ) response = await response # type: ignore self.success_calls[model_name] += 1 @@ -2702,12 +2774,14 @@ class Router: ) return response except Exception as new_exception: + parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs) verbose_router_logger.error( "litellm.router.py::async_function_with_fallbacks() - Error occurred while trying to do fallbacks - {}\n{}\n\nDebug Information:\nCooldown Deployments={}".format( str(new_exception), traceback.format_exc(), await _async_get_cooldown_deployments_with_debug_info( - litellm_router_instance=self + litellm_router_instance=self, + parent_otel_span=parent_otel_span, ), ) ) @@ -2779,12 +2853,13 @@ class Router: Context_Policy_Fallbacks={content_policy_fallbacks}", ) - async def async_function_with_retries(self, *args, **kwargs): + async def async_function_with_retries(self, *args, **kwargs): # noqa: PLR0915 verbose_router_logger.debug( f"Inside async function with retries: args - {args}; kwargs - {kwargs}" ) original_function = kwargs.pop("original_function") fallbacks = kwargs.pop("fallbacks", self.fallbacks) + parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs) context_window_fallbacks = kwargs.pop( "context_window_fallbacks", self.context_window_fallbacks ) @@ -2822,6 +2897,7 @@ class Router: _healthy_deployments, _all_deployments = ( await self._async_get_healthy_deployments( model=kwargs.get("model") or "", + parent_otel_span=parent_otel_span, ) ) @@ -2879,6 +2955,7 @@ class Router: _healthy_deployments, _ = ( await self._async_get_healthy_deployments( model=_model, + parent_otel_span=parent_otel_span, ) ) else: @@ -3217,8 +3294,10 @@ class Router: if _model is None: raise e # re-raise error, if model can't be determined for loadbalancing ### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR + parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs) _healthy_deployments, _all_deployments = self._get_healthy_deployments( model=_model, + parent_otel_span=parent_otel_span, ) # raises an exception if this error should not be retries @@ -3260,8 +3339,10 @@ class Router: if _model is None: raise e # re-raise error, if model can't be determined for loadbalancing + parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs) _healthy_deployments, _ = self._get_healthy_deployments( model=_model, + parent_otel_span=parent_otel_span, ) remaining_retries = num_retries - current_attempt _timeout = self._time_to_sleep_before_retry( @@ -3323,9 +3404,13 @@ class Router: # ------------ # update cache + parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs) ## TPM await self.cache.async_increment_cache( - key=tpm_key, value=total_tokens, ttl=RoutingArgs.ttl.value + key=tpm_key, + value=total_tokens, + parent_otel_span=parent_otel_span, + ttl=RoutingArgs.ttl.value, ) increment_deployment_successes_for_current_minute( @@ -3474,7 +3559,9 @@ class Router: except Exception as e: raise e - def _update_usage(self, deployment_id: str) -> int: + def _update_usage( + self, deployment_id: str, parent_otel_span: Optional[Span] + ) -> int: """ Update deployment rpm for that minute @@ -3483,7 +3570,9 @@ class Router: """ rpm_key = deployment_id - request_count = self.cache.get_cache(key=rpm_key, 
local_only=True) + request_count = self.cache.get_cache( + key=rpm_key, parent_otel_span=parent_otel_span, local_only=True + ) if request_count is None: request_count = 1 self.cache.set_cache( @@ -3591,7 +3680,7 @@ class Router: ) return False - def _get_healthy_deployments(self, model: str): + def _get_healthy_deployments(self, model: str, parent_otel_span: Optional[Span]): _all_deployments: list = [] try: _, _all_deployments = self._common_checks_available_deployment( # type: ignore @@ -3602,7 +3691,9 @@ class Router: except Exception: pass - unhealthy_deployments = _get_cooldown_deployments(litellm_router_instance=self) + unhealthy_deployments = _get_cooldown_deployments( + litellm_router_instance=self, parent_otel_span=parent_otel_span + ) healthy_deployments: list = [] for deployment in _all_deployments: if deployment["model_info"]["id"] in unhealthy_deployments: @@ -3613,7 +3704,7 @@ class Router: return healthy_deployments, _all_deployments async def _async_get_healthy_deployments( - self, model: str + self, model: str, parent_otel_span: Optional[Span] ) -> Tuple[List[Dict], List[Dict]]: """ Returns Tuple of: @@ -3632,7 +3723,7 @@ class Router: pass unhealthy_deployments = await _async_get_cooldown_deployments( - litellm_router_instance=self + litellm_router_instance=self, parent_otel_span=parent_otel_span ) healthy_deployments: list = [] for deployment in _all_deployments: @@ -3659,7 +3750,10 @@ class Router: _callback.pre_call_check(deployment) async def async_routing_strategy_pre_call_checks( - self, deployment: dict, logging_obj: Optional[LiteLLMLogging] = None + self, + deployment: dict, + parent_otel_span: Optional[Span], + logging_obj: Optional[LiteLLMLogging] = None, ): """ For usage-based-routing-v2, enables running rpm checks before the call is made, inside the semaphore. @@ -3675,7 +3769,7 @@ class Router: for _callback in litellm.callbacks: if isinstance(_callback, CustomLogger): try: - await _callback.async_pre_call_check(deployment) + await _callback.async_pre_call_check(deployment, parent_otel_span) except litellm.RateLimitError as e: ## LOG FAILURE EVENT if logging_obj is not None: @@ -4646,14 +4740,19 @@ class Router: The appropriate client based on the given client_type and kwargs. 
""" model_id = deployment["model_info"]["id"] + parent_otel_span: Optional[Span] = _get_parent_otel_span_from_kwargs(kwargs) if client_type == "max_parallel_requests": cache_key = "{}_max_parallel_requests_client".format(model_id) - client = self.cache.get_cache(key=cache_key, local_only=True) + client = self.cache.get_cache( + key=cache_key, local_only=True, parent_otel_span=parent_otel_span + ) return client elif client_type == "async": if kwargs.get("stream") is True: cache_key = f"{model_id}_stream_async_client" - client = self.cache.get_cache(key=cache_key, local_only=True) + client = self.cache.get_cache( + key=cache_key, local_only=True, parent_otel_span=parent_otel_span + ) if client is None: """ Re-initialize the client @@ -4661,11 +4760,17 @@ class Router: InitalizeOpenAISDKClient.set_client( litellm_router_instance=self, model=deployment ) - client = self.cache.get_cache(key=cache_key, local_only=True) + client = self.cache.get_cache( + key=cache_key, + local_only=True, + parent_otel_span=parent_otel_span, + ) return client else: cache_key = f"{model_id}_async_client" - client = self.cache.get_cache(key=cache_key, local_only=True) + client = self.cache.get_cache( + key=cache_key, local_only=True, parent_otel_span=parent_otel_span + ) if client is None: """ Re-initialize the client @@ -4673,12 +4778,18 @@ class Router: InitalizeOpenAISDKClient.set_client( litellm_router_instance=self, model=deployment ) - client = self.cache.get_cache(key=cache_key, local_only=True) + client = self.cache.get_cache( + key=cache_key, + local_only=True, + parent_otel_span=parent_otel_span, + ) return client else: if kwargs.get("stream") is True: cache_key = f"{model_id}_stream_client" - client = self.cache.get_cache(key=cache_key) + client = self.cache.get_cache( + key=cache_key, parent_otel_span=parent_otel_span + ) if client is None: """ Re-initialize the client @@ -4686,11 +4797,15 @@ class Router: InitalizeOpenAISDKClient.set_client( litellm_router_instance=self, model=deployment ) - client = self.cache.get_cache(key=cache_key) + client = self.cache.get_cache( + key=cache_key, parent_otel_span=parent_otel_span + ) return client else: cache_key = f"{model_id}_client" - client = self.cache.get_cache(key=cache_key) + client = self.cache.get_cache( + key=cache_key, parent_otel_span=parent_otel_span + ) if client is None: """ Re-initialize the client @@ -4698,7 +4813,9 @@ class Router: InitalizeOpenAISDKClient.set_client( litellm_router_instance=self, model=deployment ) - client = self.cache.get_cache(key=cache_key) + client = self.cache.get_cache( + key=cache_key, parent_otel_span=parent_otel_span + ) return client def _pre_call_checks( # noqa: PLR0915 @@ -4738,13 +4855,17 @@ class Router: _context_window_error = False _potential_error_str = "" _rate_limit_error = False + parent_otel_span = _get_parent_otel_span_from_kwargs(request_kwargs) ## get model group RPM ## dt = get_utc_datetime() current_minute = dt.strftime("%H-%M") rpm_key = f"{model}:rpm:{current_minute}" model_group_cache = ( - self.cache.get_cache(key=rpm_key, local_only=True) or {} + self.cache.get_cache( + key=rpm_key, local_only=True, parent_otel_span=parent_otel_span + ) + or {} ) # check the in-memory cache used by lowest_latency and usage-based routing. Only check the local cache. 
for idx, deployment in enumerate(_returned_deployments): # see if we have the info for this model @@ -4783,7 +4904,10 @@ class Router: ## RPM CHECK ## ### get local router cache ### current_request_cache_local = ( - self.cache.get_cache(key=model_id, local_only=True) or 0 + self.cache.get_cache( + key=model_id, local_only=True, parent_otel_span=parent_otel_span + ) + or 0 ) ### get usage based cache ### if ( @@ -5002,6 +5126,7 @@ class Router: self.routing_strategy != "usage-based-routing-v2" and self.routing_strategy != "simple-shuffle" and self.routing_strategy != "cost-based-routing" + and self.routing_strategy != "latency-based-routing" ): # prevent regressions for other routing strategies, that don't have async get available deployments implemented. return self.get_available_deployment( model=model, @@ -5011,6 +5136,7 @@ class Router: request_kwargs=request_kwargs, ) try: + parent_otel_span = _get_parent_otel_span_from_kwargs(request_kwargs) model, healthy_deployments = self._common_checks_available_deployment( model=model, messages=messages, @@ -5021,7 +5147,7 @@ class Router: return healthy_deployments cooldown_deployments = await _async_get_cooldown_deployments( - litellm_router_instance=self + litellm_router_instance=self, parent_otel_span=parent_otel_span ) verbose_router_logger.debug( f"async cooldown deployments: {cooldown_deployments}" @@ -5059,16 +5185,18 @@ class Router: _allowed_model_region = "n/a" model_ids = self.get_model_ids(model_name=model) _cooldown_time = self.cooldown_cache.get_min_cooldown( - model_ids=model_ids + model_ids=model_ids, parent_otel_span=parent_otel_span + ) + _cooldown_list = _get_cooldown_deployments( + litellm_router_instance=self, parent_otel_span=parent_otel_span ) - _cooldown_list = _get_cooldown_deployments(litellm_router_instance=self) raise RouterRateLimitError( model=model, cooldown_time=_cooldown_time, enable_pre_call_checks=self.enable_pre_call_checks, cooldown_list=_cooldown_list, ) - + start_time = time.time() if ( self.routing_strategy == "usage-based-routing-v2" and self.lowesttpm_logger_v2 is not None @@ -5093,6 +5221,19 @@ class Router: input=input, ) ) + elif ( + self.routing_strategy == "latency-based-routing" + and self.lowestlatency_logger is not None + ): + deployment = ( + await self.lowestlatency_logger.async_get_available_deployments( + model_group=model, + healthy_deployments=healthy_deployments, # type: ignore + messages=messages, + input=input, + request_kwargs=request_kwargs, + ) + ) elif self.routing_strategy == "simple-shuffle": return simple_shuffle( llm_router_instance=self, @@ -5107,9 +5248,11 @@ class Router: ) model_ids = self.get_model_ids(model_name=model) _cooldown_time = self.cooldown_cache.get_min_cooldown( - model_ids=model_ids + model_ids=model_ids, parent_otel_span=parent_otel_span + ) + _cooldown_list = _get_cooldown_deployments( + litellm_router_instance=self, parent_otel_span=parent_otel_span ) - _cooldown_list = _get_cooldown_deployments(litellm_router_instance=self) raise RouterRateLimitError( model=model, cooldown_time=_cooldown_time, @@ -5120,6 +5263,19 @@ class Router: f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment)} for model: {model}" ) + end_time = time.time() + _duration = end_time - start_time + asyncio.create_task( + self.service_logger_obj.async_service_success_hook( + service=ServiceTypes.ROUTER, + duration=_duration, + call_type=".async_get_available_deployments", + parent_otel_span=parent_otel_span, + start_time=start_time, + 
end_time=end_time, + ) + ) + return deployment except Exception as e: traceback_exception = traceback.format_exc() @@ -5163,7 +5319,12 @@ class Router: if isinstance(healthy_deployments, dict): return healthy_deployments - cooldown_deployments = _get_cooldown_deployments(litellm_router_instance=self) + parent_otel_span: Optional[Span] = _get_parent_otel_span_from_kwargs( + request_kwargs + ) + cooldown_deployments = _get_cooldown_deployments( + litellm_router_instance=self, parent_otel_span=parent_otel_span + ) healthy_deployments = self._filter_cooldown_deployments( healthy_deployments=healthy_deployments, cooldown_deployments=cooldown_deployments, @@ -5180,8 +5341,12 @@ class Router: if len(healthy_deployments) == 0: model_ids = self.get_model_ids(model_name=model) - _cooldown_time = self.cooldown_cache.get_min_cooldown(model_ids=model_ids) - _cooldown_list = _get_cooldown_deployments(litellm_router_instance=self) + _cooldown_time = self.cooldown_cache.get_min_cooldown( + model_ids=model_ids, parent_otel_span=parent_otel_span + ) + _cooldown_list = _get_cooldown_deployments( + litellm_router_instance=self, parent_otel_span=parent_otel_span + ) raise RouterRateLimitError( model=model, cooldown_time=_cooldown_time, @@ -5238,8 +5403,12 @@ class Router: f"get_available_deployment for model: {model}, No deployment available" ) model_ids = self.get_model_ids(model_name=model) - _cooldown_time = self.cooldown_cache.get_min_cooldown(model_ids=model_ids) - _cooldown_list = _get_cooldown_deployments(litellm_router_instance=self) + _cooldown_time = self.cooldown_cache.get_min_cooldown( + model_ids=model_ids, parent_otel_span=parent_otel_span + ) + _cooldown_list = _get_cooldown_deployments( + litellm_router_instance=self, parent_otel_span=parent_otel_span + ) raise RouterRateLimitError( model=model, cooldown_time=_cooldown_time, @@ -5278,7 +5447,9 @@ class Router: healthy_deployments.remove(deployment) return healthy_deployments - def _track_deployment_metrics(self, deployment, response=None): + def _track_deployment_metrics( + self, deployment, parent_otel_span: Optional[Span], response=None + ): """ Tracks successful requests rpm usage. 
""" @@ -5288,7 +5459,9 @@ class Router: # update self.deployment_stats if model_id is not None: - self._update_usage(model_id) # update in-memory cache for tracking + self._update_usage( + model_id, parent_otel_span + ) # update in-memory cache for tracking except Exception as e: verbose_router_logger.error(f"Error in _track_deployment_metrics: {str(e)}") diff --git a/litellm/router_strategy/lowest_latency.py b/litellm/router_strategy/lowest_latency.py index fc47d64c7..287e60146 100644 --- a/litellm/router_strategy/lowest_latency.py +++ b/litellm/router_strategy/lowest_latency.py @@ -3,7 +3,7 @@ import random import traceback from datetime import datetime, timedelta -from typing import Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from pydantic import BaseModel @@ -11,6 +11,14 @@ import litellm from litellm import ModelResponse, token_counter, verbose_logger from litellm.caching.caching import DualCache from litellm.integrations.custom_logger import CustomLogger +from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs + +if TYPE_CHECKING: + from opentelemetry.trace import Span as _Span + + Span = _Span +else: + Span = Any class LiteLLMBase(BaseModel): @@ -115,8 +123,13 @@ class LowestLatencyLoggingHandler(CustomLogger): # ------------ # Update usage # ------------ - - request_count_dict = self.router_cache.get_cache(key=latency_key) or {} + parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs) + request_count_dict = ( + self.router_cache.get_cache( + key=latency_key, parent_otel_span=parent_otel_span + ) + or {} + ) if id not in request_count_dict: request_count_dict[id] = {} @@ -213,7 +226,7 @@ class LowestLatencyLoggingHandler(CustomLogger): """ latency_key = f"{model_group}_map" request_count_dict = ( - self.router_cache.get_cache(key=latency_key) or {} + await self.router_cache.async_get_cache(key=latency_key) or {} ) if id not in request_count_dict: @@ -316,8 +329,15 @@ class LowestLatencyLoggingHandler(CustomLogger): # ------------ # Update usage # ------------ - - request_count_dict = self.router_cache.get_cache(key=latency_key) or {} + parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs) + request_count_dict = ( + await self.router_cache.async_get_cache( + key=latency_key, + parent_otel_span=parent_otel_span, + local_only=True, + ) + or {} + ) if id not in request_count_dict: request_count_dict[id] = {} @@ -379,26 +399,21 @@ class LowestLatencyLoggingHandler(CustomLogger): ) pass - def get_available_deployments( # noqa: PLR0915 + def _get_available_deployments( # noqa: PLR0915 self, model_group: str, healthy_deployments: list, messages: Optional[List[Dict[str, str]]] = None, input: Optional[Union[str, List]] = None, request_kwargs: Optional[Dict] = None, + request_count_dict: Optional[Dict] = None, ): - """ - Returns a deployment with the lowest latency - """ - # get list of potential deployments - latency_key = f"{model_group}_map" - _latency_per_deployment = {} - - request_count_dict = self.router_cache.get_cache(key=latency_key) or {} + """Common logic for both sync and async get_available_deployments""" # ----------------------- # Find lowest used model # ---------------------- + _latency_per_deployment = {} lowest_latency = float("inf") current_date = datetime.now().strftime("%Y-%m-%d") @@ -428,8 +443,8 @@ class LowestLatencyLoggingHandler(CustomLogger): # randomly sample from all_deployments, incase all deployments have latency=0.0 _items = all_deployments.items() - all_deployments 
= random.sample(list(_items), len(_items)) - all_deployments = dict(all_deployments) + _all_deployments = random.sample(list(_items), len(_items)) + all_deployments = dict(_all_deployments) ### GET AVAILABLE DEPLOYMENTS ### filter out any deployments > tpm/rpm limits potential_deployments = [] @@ -525,3 +540,66 @@ class LowestLatencyLoggingHandler(CustomLogger): "_latency_per_deployment" ] = _latency_per_deployment return deployment + + async def async_get_available_deployments( + self, + model_group: str, + healthy_deployments: list, + messages: Optional[List[Dict[str, str]]] = None, + input: Optional[Union[str, List]] = None, + request_kwargs: Optional[Dict] = None, + ): + # get list of potential deployments + latency_key = f"{model_group}_map" + + parent_otel_span: Optional[Span] = _get_parent_otel_span_from_kwargs( + request_kwargs + ) + request_count_dict = ( + await self.router_cache.async_get_cache( + key=latency_key, parent_otel_span=parent_otel_span + ) + or {} + ) + + return self._get_available_deployments( + model_group, + healthy_deployments, + messages, + input, + request_kwargs, + request_count_dict, + ) + + def get_available_deployments( + self, + model_group: str, + healthy_deployments: list, + messages: Optional[List[Dict[str, str]]] = None, + input: Optional[Union[str, List]] = None, + request_kwargs: Optional[Dict] = None, + ): + """ + Returns a deployment with the lowest latency + """ + # get list of potential deployments + latency_key = f"{model_group}_map" + + parent_otel_span: Optional[Span] = _get_parent_otel_span_from_kwargs( + request_kwargs + ) + request_count_dict = ( + self.router_cache.get_cache( + key=latency_key, parent_otel_span=parent_otel_span + ) + or {} + ) + + return self._get_available_deployments( + model_group, + healthy_deployments, + messages, + input, + request_kwargs, + request_count_dict, + ) diff --git a/litellm/router_strategy/lowest_tpm_rpm_v2.py b/litellm/router_strategy/lowest_tpm_rpm_v2.py index 2c62b6c7b..17ff0cc09 100644 --- a/litellm/router_strategy/lowest_tpm_rpm_v2.py +++ b/litellm/router_strategy/lowest_tpm_rpm_v2.py @@ -2,7 +2,7 @@ # identifies lowest tpm deployment import random import traceback -from typing import Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union import httpx from pydantic import BaseModel @@ -12,9 +12,17 @@ from litellm import token_counter from litellm._logging import verbose_logger, verbose_router_logger from litellm.caching.caching import DualCache from litellm.integrations.custom_logger import CustomLogger +from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs from litellm.types.router import RouterErrors from litellm.utils import get_utc_datetime, print_verbose +if TYPE_CHECKING: + from opentelemetry.trace import Span as _Span + + Span = _Span +else: + Span = Any + class LiteLLMBase(BaseModel): """ @@ -136,7 +144,9 @@ class LowestTPMLoggingHandler_v2(CustomLogger): raise e return deployment # don't fail calls if eg. 
redis fails to connect - async def async_pre_call_check(self, deployment: Dict) -> Optional[Dict]: + async def async_pre_call_check( + self, deployment: Dict, parent_otel_span: Optional[Span] + ) -> Optional[Dict]: """ Pre-call check + update model rpm - Used inside semaphore @@ -192,7 +202,10 @@ class LowestTPMLoggingHandler_v2(CustomLogger): else: # if local result below limit, check redis ## prevent unnecessary redis checks result = await self.router_cache.async_increment_cache( - key=rpm_key, value=1, ttl=self.routing_args.ttl + key=rpm_key, + value=1, + ttl=self.routing_args.ttl, + parent_otel_span=parent_otel_span, ) if result is not None and result > deployment_rpm: raise litellm.RateLimitError( @@ -301,10 +314,13 @@ class LowestTPMLoggingHandler_v2(CustomLogger): # Update usage # ------------ # update cache - + parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs) ## TPM await self.router_cache.async_increment_cache( - key=tpm_key, value=total_tokens, ttl=self.routing_args.ttl + key=tpm_key, + value=total_tokens, + ttl=self.routing_args.ttl, + parent_otel_span=parent_otel_span, ) ### TESTING ### @@ -547,6 +563,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger): healthy_deployments: list, messages: Optional[List[Dict[str, str]]] = None, input: Optional[Union[str, List]] = None, + parent_otel_span: Optional[Span] = None, ): """ Returns a deployment with the lowest TPM/RPM usage. @@ -572,10 +589,10 @@ class LowestTPMLoggingHandler_v2(CustomLogger): rpm_keys.append(rpm_key) tpm_values = self.router_cache.batch_get_cache( - keys=tpm_keys + keys=tpm_keys, parent_otel_span=parent_otel_span ) # [1, 2, None, ..] rpm_values = self.router_cache.batch_get_cache( - keys=rpm_keys + keys=rpm_keys, parent_otel_span=parent_otel_span ) # [1, 2, None, ..] deployment = self._common_checks_available_deployment( diff --git a/litellm/router_utils/cooldown_cache.py b/litellm/router_utils/cooldown_cache.py index e30b4a605..792d91811 100644 --- a/litellm/router_utils/cooldown_cache.py +++ b/litellm/router_utils/cooldown_cache.py @@ -4,11 +4,18 @@ Wrapper around router cache. 
Meant to handle model cooldown logic import json import time -from typing import List, Optional, Tuple, TypedDict +from typing import TYPE_CHECKING, Any, List, Optional, Tuple, TypedDict from litellm import verbose_logger from litellm.caching.caching import DualCache +if TYPE_CHECKING: + from opentelemetry.trace import Span as _Span + + Span = _Span +else: + Span = Any + class CooldownCacheValue(TypedDict): exception_received: str @@ -77,13 +84,18 @@ class CooldownCache: raise e async def async_get_active_cooldowns( - self, model_ids: List[str] + self, model_ids: List[str], parent_otel_span: Optional[Span] ) -> List[Tuple[str, CooldownCacheValue]]: # Generate the keys for the deployments keys = [f"deployment:{model_id}:cooldown" for model_id in model_ids] # Retrieve the values for the keys using mget - results = await self.cache.async_batch_get_cache(keys=keys) or [] + results = ( + await self.cache.async_batch_get_cache( + keys=keys, parent_otel_span=parent_otel_span + ) + or [] + ) active_cooldowns = [] # Process the results @@ -95,13 +107,15 @@ class CooldownCache: return active_cooldowns def get_active_cooldowns( - self, model_ids: List[str] + self, model_ids: List[str], parent_otel_span: Optional[Span] ) -> List[Tuple[str, CooldownCacheValue]]: # Generate the keys for the deployments keys = [f"deployment:{model_id}:cooldown" for model_id in model_ids] - # Retrieve the values for the keys using mget - results = self.cache.batch_get_cache(keys=keys) or [] + results = ( + self.cache.batch_get_cache(keys=keys, parent_otel_span=parent_otel_span) + or [] + ) active_cooldowns = [] # Process the results @@ -112,14 +126,19 @@ class CooldownCache: return active_cooldowns - def get_min_cooldown(self, model_ids: List[str]) -> float: + def get_min_cooldown( + self, model_ids: List[str], parent_otel_span: Optional[Span] + ) -> float: """Return min cooldown time required for a group of model id's.""" # Generate the keys for the deployments keys = [f"deployment:{model_id}:cooldown" for model_id in model_ids] # Retrieve the values for the keys using mget - results = self.cache.batch_get_cache(keys=keys) or [] + results = ( + self.cache.batch_get_cache(keys=keys, parent_otel_span=parent_otel_span) + or [] + ) min_cooldown_time: Optional[float] = None # Process the results diff --git a/litellm/router_utils/cooldown_handlers.py b/litellm/router_utils/cooldown_handlers.py index b5c5de2fe..42864d986 100644 --- a/litellm/router_utils/cooldown_handlers.py +++ b/litellm/router_utils/cooldown_handlers.py @@ -20,12 +20,15 @@ from .router_callbacks.track_deployment_metrics import ( ) if TYPE_CHECKING: + from opentelemetry.trace import Span as _Span + from litellm.router import Router as _Router LitellmRouter = _Router + Span = _Span else: LitellmRouter = Any - + Span = Any DEFAULT_FAILURE_THRESHOLD_PERCENT = ( 0.5 # default cooldown a deployment if 50% of requests fail in a given minute ) @@ -207,6 +210,7 @@ def _set_cooldown_deployments( async def _async_get_cooldown_deployments( litellm_router_instance: LitellmRouter, + parent_otel_span: Optional[Span], ) -> List[str]: """ Async implementation of '_get_cooldown_deployments' @@ -214,7 +218,8 @@ async def _async_get_cooldown_deployments( model_ids = litellm_router_instance.get_model_ids() cooldown_models = ( await litellm_router_instance.cooldown_cache.async_get_active_cooldowns( - model_ids=model_ids + model_ids=model_ids, + parent_otel_span=parent_otel_span, ) ) @@ -233,6 +238,7 @@ async def _async_get_cooldown_deployments( async def 
_async_get_cooldown_deployments_with_debug_info( litellm_router_instance: LitellmRouter, + parent_otel_span: Optional[Span], ) -> List[tuple]: """ Async implementation of '_get_cooldown_deployments' @@ -240,7 +246,7 @@ async def _async_get_cooldown_deployments_with_debug_info( model_ids = litellm_router_instance.get_model_ids() cooldown_models = ( await litellm_router_instance.cooldown_cache.async_get_active_cooldowns( - model_ids=model_ids + model_ids=model_ids, parent_otel_span=parent_otel_span ) ) @@ -248,7 +254,9 @@ async def _async_get_cooldown_deployments_with_debug_info( return cooldown_models -def _get_cooldown_deployments(litellm_router_instance: LitellmRouter) -> List[str]: +def _get_cooldown_deployments( + litellm_router_instance: LitellmRouter, parent_otel_span: Optional[Span] +) -> List[str]: """ Get the list of models being cooled down for this minute """ @@ -258,8 +266,9 @@ def _get_cooldown_deployments(litellm_router_instance: LitellmRouter) -> List[st # Return cooldown models # ---------------------- model_ids = litellm_router_instance.get_model_ids() + cooldown_models = litellm_router_instance.cooldown_cache.get_active_cooldowns( - model_ids=model_ids + model_ids=model_ids, parent_otel_span=parent_otel_span ) cached_value_deployment_ids = [] diff --git a/litellm/types/services.py b/litellm/types/services.py index 62046ff44..08259c741 100644 --- a/litellm/types/services.py +++ b/litellm/types/services.py @@ -14,6 +14,7 @@ class ServiceTypes(str, enum.Enum): DB = "postgres" BATCH_WRITE_TO_DB = "batch_write_to_db" LITELLM = "self" + ROUTER = "router" class ServiceLoggerPayload(BaseModel): diff --git a/tests/local_testing/test_caching.py b/tests/local_testing/test_caching.py index f56079aa7..b19585430 100644 --- a/tests/local_testing/test_caching.py +++ b/tests/local_testing/test_caching.py @@ -83,7 +83,9 @@ def test_dual_cache_batch_get_cache(): in_memory_cache.set_cache(key="test_value", value="hello world") - result = dual_cache.batch_get_cache(keys=["test_value", "test_value_2"]) + result = dual_cache.batch_get_cache( + keys=["test_value", "test_value_2"], parent_otel_span=None + ) assert result[0] == "hello world" assert result[1] == None diff --git a/tests/local_testing/test_router.py b/tests/local_testing/test_router.py index a6316233a..d360d7317 100644 --- a/tests/local_testing/test_router.py +++ b/tests/local_testing/test_router.py @@ -2447,11 +2447,11 @@ async def test_aaarouter_dynamic_cooldown_message_retry_time(sync_mode): if sync_mode: cooldown_deployments = _get_cooldown_deployments( - litellm_router_instance=router + litellm_router_instance=router, parent_otel_span=None ) else: cooldown_deployments = await _async_get_cooldown_deployments( - litellm_router_instance=router + litellm_router_instance=router, parent_otel_span=None ) print( "Cooldown deployments - {}\n{}".format( diff --git a/tests/local_testing/test_router_cooldowns.py b/tests/local_testing/test_router_cooldowns.py index deeec6c29..774b36e2a 100644 --- a/tests/local_testing/test_router_cooldowns.py +++ b/tests/local_testing/test_router_cooldowns.py @@ -242,12 +242,12 @@ async def test_single_deployment_no_cooldowns_test_prod_mock_completion_calls(): pass cooldown_list = await _async_get_cooldown_deployments( - litellm_router_instance=router + litellm_router_instance=router, parent_otel_span=None ) assert len(cooldown_list) == 0 healthy_deployments, _ = await router._async_get_healthy_deployments( - model="gpt-3.5-turbo" + model="gpt-3.5-turbo", parent_otel_span=None ) print("healthy_deployments: ", 
healthy_deployments) @@ -351,7 +351,7 @@ async def test_high_traffic_cooldowns_all_healthy_deployments(): print("model_stats: ", model_stats) cooldown_list = await _async_get_cooldown_deployments( - litellm_router_instance=router + litellm_router_instance=router, parent_otel_span=None ) assert len(cooldown_list) == 0 @@ -449,7 +449,7 @@ async def test_high_traffic_cooldowns_one_bad_deployment(): print("model_stats: ", model_stats) cooldown_list = await _async_get_cooldown_deployments( - litellm_router_instance=router + litellm_router_instance=router, parent_otel_span=None ) assert len(cooldown_list) == 1 @@ -550,7 +550,7 @@ async def test_high_traffic_cooldowns_one_rate_limited_deployment(): print("model_stats: ", model_stats) cooldown_list = await _async_get_cooldown_deployments( - litellm_router_instance=router + litellm_router_instance=router, parent_otel_span=None ) assert len(cooldown_list) == 1 diff --git a/tests/router_unit_tests/test_router_helper_utils.py b/tests/router_unit_tests/test_router_helper_utils.py index 78e322764..0231e199f 100644 --- a/tests/router_unit_tests/test_router_helper_utils.py +++ b/tests/router_unit_tests/test_router_helper_utils.py @@ -440,12 +440,12 @@ def test_update_usage(model_list): ) deployment_id = deployment["model_info"]["id"] request_count = router._update_usage( - deployment_id=deployment_id, + deployment_id=deployment_id, parent_otel_span=None ) assert request_count == 1 request_count = router._update_usage( - deployment_id=deployment_id, + deployment_id=deployment_id, parent_otel_span=None ) assert request_count == 2 @@ -482,7 +482,9 @@ def test_should_raise_content_policy_error(model_list, finish_reason, expected_e def test_get_healthy_deployments(model_list): """Test if the '_get_healthy_deployments' function is working correctly""" router = Router(model_list=model_list) - deployments = router._get_healthy_deployments(model="gpt-3.5-turbo") + deployments = router._get_healthy_deployments( + model="gpt-3.5-turbo", parent_otel_span=None + ) assert len(deployments) > 0 @@ -756,6 +758,7 @@ def test_track_deployment_metrics(model_list): model="gpt-3.5-turbo", usage={"total_tokens": 100}, ), + parent_otel_span=None, )
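Note: the hunks above all apply one pattern - every router and cooldown code path that reads or writes the shared cache now accepts an optional parent_otel_span and forwards it to the DualCache call, so the Redis round-trip is recorded as a child span of the originating request instead of showing up as a detached trace. Below is a minimal sketch of that read path, assuming the DualCache.batch_get_cache signature introduced in this patch; the helper name read_active_cooldowns_example is hypothetical and only illustrates how the span is threaded through, it is not part of the library.

from typing import TYPE_CHECKING, Any, List, Optional

from litellm.caching.caching import DualCache

if TYPE_CHECKING:
    from opentelemetry.trace import Span as _Span

    Span = _Span
else:
    Span = Any


def read_active_cooldowns_example(
    cache: DualCache,
    model_ids: List[str],
    parent_otel_span: Optional[Span] = None,
) -> list:
    # Same key scheme CooldownCache uses in this patch.
    keys = [f"deployment:{model_id}:cooldown" for model_id in model_ids]
    # Forwarding parent_otel_span lets the Redis batch read be attributed
    # to the caller's OTEL trace; passing None (as the updated tests do)
    # keeps the previous, untraced behavior.
    return cache.batch_get_cache(keys=keys, parent_otel_span=parent_otel_span) or []

Call sites without an active trace, such as the updated tests, simply pass parent_otel_span=None.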