redis otel tracing + async support for latency routing (#6452)

* docs(exception_mapping.md): add missing exception types

Fixes https://github.com/Aider-AI/aider/issues/2120#issuecomment-2438971183

* fix(main.py): register custom model pricing with specific key

Ensure custom model pricing is registered under the specific model+provider key combination
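
A minimal sketch of the idea, using litellm's public register_model API (the model name and prices below are made up):

    import litellm

    # Keying the pricing entry by provider/model instead of the bare model
    # name keeps a custom "openai/my-custom-model" entry from colliding with
    # a same-named model served by a different provider.
    litellm.register_model(
        {
            "openai/my-custom-model": {  # hypothetical model+provider key
                "input_cost_per_token": 0.0000006,
                "output_cost_per_token": 0.0000012,
                "litellm_provider": "openai",
                "mode": "chat",
            }
        }
    )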

* test: make testing more robust for custom pricing

* fix(redis_cache.py): instrument otel logging for sync redis calls

ensures complete coverage for all redis cache calls
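
A sketch of the instrumentation pattern (span name and attribute are illustrative, not necessarily what redis_cache.py uses):

    import redis
    from opentelemetry import trace

    tracer = trace.get_tracer(__name__)
    client = redis.Redis(host="localhost", port=6379)

    def traced_get(key: str):
        # Wrap the blocking redis call in a span so sync cache reads show
        # up in traces alongside the already-instrumented async paths.
        with tracer.start_as_current_span("redis.get_cache") as span:
            span.set_attribute("redis.key", key)
            return client.get(key)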

* refactor: pass parent_otel_span for redis caching calls in router

allows for more observability into which calls are causing latency issues
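
The pattern, roughly: cache methods accept an optional parent span and open their own span under it, so each cache call shows up as a child of the request that triggered it (function shape below is illustrative):

    from typing import Optional

    from opentelemetry import trace
    from opentelemetry.trace import Span

    tracer = trace.get_tracer(__name__)

    def get_cache(key: str, parent_otel_span: Optional[Span] = None):
        # Attach this span to the caller's request span when one is passed,
        # so per-request traces show which cache lookups added latency.
        ctx = (
            trace.set_span_in_context(parent_otel_span)
            if parent_otel_span is not None
            else None
        )
        with tracer.start_as_current_span("redis.get_cache", context=ctx):
            ...  # perform the actual redis read here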

* test: update tests with new params

* refactor: ensure e2e otel tracing for router

* refactor(router.py): add more otel tracing across the router

makes it possible to catch all sources of latency in router requests

* fix: fix linting error

* fix(router.py): fix linting error

* fix: fix test

* test: fix tests

* fix(dual_cache.py): pass ttl to redis cache
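
Likely shape of the fix, sketched from DualCache's write-through behavior (the class and method signatures below are approximations, not the real litellm code):

    from typing import Any, Optional

    class DualCacheSketch:
        """Illustrative stand-in for litellm's DualCache."""

        def __init__(self, in_memory_cache, redis_cache=None):
            self.in_memory_cache = in_memory_cache
            self.redis_cache = redis_cache

        async def async_set_cache(
            self, key: str, value: Any, ttl: Optional[float] = None, **kwargs
        ):
            # Write-through: the in-memory layer already honored ttl; the bug
            # was dropping it on the redis write, so keys never expired there.
            await self.in_memory_cache.async_set_cache(key, value, ttl=ttl, **kwargs)
            if self.redis_cache is not None:
                await self.redis_cache.async_set_cache(key, value, ttl=ttl, **kwargs)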

* fix: fix param
Krish Dholakia · 2024-10-28 21:52:12 -07:00 · committed by GitHub
parent d9e7818e6b · commit 4f8a3fd4cf
25 changed files with 559 additions and 147 deletions

litellm/router_strategy/lowest_latency.py

@@ -3,7 +3,7 @@
 import random
 import traceback
 from datetime import datetime, timedelta
-from typing import Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

 from pydantic import BaseModel
@@ -11,6 +11,14 @@ import litellm
 from litellm import ModelResponse, token_counter, verbose_logger
 from litellm.caching.caching import DualCache
 from litellm.integrations.custom_logger import CustomLogger
+from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
+
+if TYPE_CHECKING:
+    from opentelemetry.trace import Span as _Span
+
+    Span = _Span
+else:
+    Span = Any
+

 class LiteLLMBase(BaseModel):
@@ -115,8 +123,13 @@ class LowestLatencyLoggingHandler(CustomLogger):
                 # ------------
                 # Update usage
                 # ------------
-                request_count_dict = self.router_cache.get_cache(key=latency_key) or {}
+                parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
+                request_count_dict = (
+                    self.router_cache.get_cache(
+                        key=latency_key, parent_otel_span=parent_otel_span
+                    )
+                    or {}
+                )

                 if id not in request_count_dict:
                     request_count_dict[id] = {}
@@ -213,7 +226,7 @@ class LowestLatencyLoggingHandler(CustomLogger):
         """
         latency_key = f"{model_group}_map"
         request_count_dict = (
-            self.router_cache.get_cache(key=latency_key) or {}
+            await self.router_cache.async_get_cache(key=latency_key) or {}
         )

         if id not in request_count_dict:
@@ -316,8 +329,15 @@ class LowestLatencyLoggingHandler(CustomLogger):
                 # ------------
                 # Update usage
                 # ------------
-                request_count_dict = self.router_cache.get_cache(key=latency_key) or {}
+                parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
+                request_count_dict = (
+                    await self.router_cache.async_get_cache(
+                        key=latency_key,
+                        parent_otel_span=parent_otel_span,
+                        local_only=True,
+                    )
+                    or {}
+                )

                 if id not in request_count_dict:
                     request_count_dict[id] = {}
@@ -379,26 +399,21 @@ class LowestLatencyLoggingHandler(CustomLogger):
             )
             pass

-    def get_available_deployments(  # noqa: PLR0915
+    def _get_available_deployments(  # noqa: PLR0915
         self,
         model_group: str,
         healthy_deployments: list,
         messages: Optional[List[Dict[str, str]]] = None,
         input: Optional[Union[str, List]] = None,
         request_kwargs: Optional[Dict] = None,
+        request_count_dict: Optional[Dict] = None,
     ):
-        """
-        Returns a deployment with the lowest latency
-        """
-        # get list of potential deployments
-        latency_key = f"{model_group}_map"
-        _latency_per_deployment = {}
-        request_count_dict = self.router_cache.get_cache(key=latency_key) or {}
+        """Common logic for both sync and async get_available_deployments"""

         # -----------------------
         # Find lowest used model
         # ----------------------
+        _latency_per_deployment = {}
         lowest_latency = float("inf")

         current_date = datetime.now().strftime("%Y-%m-%d")
@@ -428,8 +443,8 @@ class LowestLatencyLoggingHandler(CustomLogger):
         # randomly sample from all_deployments, incase all deployments have latency=0.0
         _items = all_deployments.items()

-        all_deployments = random.sample(list(_items), len(_items))
-        all_deployments = dict(all_deployments)
+        _all_deployments = random.sample(list(_items), len(_items))
+        all_deployments = dict(_all_deployments)

         ### GET AVAILABLE DEPLOYMENTS ### filter out any deployments > tpm/rpm limits
         potential_deployments = []
@@ -525,3 +540,66 @@ class LowestLatencyLoggingHandler(CustomLogger):
                 "_latency_per_deployment"
             ] = _latency_per_deployment
         return deployment
+
+    async def async_get_available_deployments(
+        self,
+        model_group: str,
+        healthy_deployments: list,
+        messages: Optional[List[Dict[str, str]]] = None,
+        input: Optional[Union[str, List]] = None,
+        request_kwargs: Optional[Dict] = None,
+    ):
+        # get list of potential deployments
+        latency_key = f"{model_group}_map"
+
+        parent_otel_span: Optional[Span] = _get_parent_otel_span_from_kwargs(
+            request_kwargs
+        )
+
+        request_count_dict = (
+            await self.router_cache.async_get_cache(
+                key=latency_key, parent_otel_span=parent_otel_span
+            )
+            or {}
+        )
+
+        return self._get_available_deployments(
+            model_group,
+            healthy_deployments,
+            messages,
+            input,
+            request_kwargs,
+            request_count_dict,
+        )
+
+    def get_available_deployments(
+        self,
+        model_group: str,
+        healthy_deployments: list,
+        messages: Optional[List[Dict[str, str]]] = None,
+        input: Optional[Union[str, List]] = None,
+        request_kwargs: Optional[Dict] = None,
+    ):
+        """
+        Returns a deployment with the lowest latency
+        """
+        # get list of potential deployments
+        latency_key = f"{model_group}_map"
+
+        parent_otel_span: Optional[Span] = _get_parent_otel_span_from_kwargs(
+            request_kwargs
+        )
+
+        request_count_dict = (
+            self.router_cache.get_cache(
+                key=latency_key, parent_otel_span=parent_otel_span
+            )
+            or {}
+        )
+
+        return self._get_available_deployments(
+            model_group,
+            healthy_deployments,
+            messages,
+            input,
+            request_kwargs,
+            request_count_dict,
+        )
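
For context, a rough sketch of how the new async selection path gets exercised (the cache wiring and model group below are illustrative, not taken from this diff):

    from litellm.caching.caching import DualCache
    from litellm.router_strategy.lowest_latency import LowestLatencyLoggingHandler

    async def pick_deployment(healthy_deployments: list):
        handler = LowestLatencyLoggingHandler(
            router_cache=DualCache(), model_list=[]
        )
        # async_get_available_deployments reads the latency map via
        # async_get_cache, then defers to the shared
        # _get_available_deployments scoring logic.
        return await handler.async_get_available_deployments(
            model_group="gpt-4o",  # hypothetical model group
            healthy_deployments=healthy_deployments,
            request_kwargs={},  # the proxy normally carries the parent otel span here
        )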