(litellm SDK perf improvements) - handle cases when unable to lookup model in model cost map (#7750)

* use lru cache wrapper

* use lru_cache_wrapper for _cached_get_model_info_helper

* fix _get_traceback_str_for_error

* huggingface/mistralai/Mistral-7B-Instruct-v0.3
commit d88f01d518 (parent c8ac61f117)
Author: Ishaan Jaff, 2025-01-13 19:58:46 -08:00 (committed by GitHub)
3 changed files with 42 additions and 2 deletions


@@ -0,0 +1,30 @@
from functools import lru_cache
from typing import Callable, Optional, TypeVar

T = TypeVar("T")


def lru_cache_wrapper(
    maxsize: Optional[int] = None,
) -> Callable[[Callable[..., T]], Callable[..., T]]:
    """
    Wrapper for lru_cache that caches success and exceptions
    """

    def decorator(f: Callable[..., T]) -> Callable[..., T]:
        @lru_cache(maxsize=maxsize)
        def wrapper(*args, **kwargs):
            try:
                return ("success", f(*args, **kwargs))
            except Exception as e:
                return ("error", e)

        def wrapped(*args, **kwargs):
            result = wrapper(*args, **kwargs)
            if result[0] == "error":
                raise result[1]
            return result[1]

        return wrapped

    return decorator
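
To illustrate how the new decorator behaves, here is a minimal usage sketch; lookup_cost is a hypothetical function used only for this example, and the import path is assumed from the import added to litellm further down in this commit. Both a returned value and a raised exception are memoized, so repeated calls with the same arguments never re-run the body:

from litellm.caching._internal_lru_cache import lru_cache_wrapper  # path assumed from the import added below

@lru_cache_wrapper(maxsize=16)
def lookup_cost(model: str) -> float:
    # hypothetical lookup, not part of this commit
    print(f"looking up {model}")
    if model.startswith("unknown/"):
        raise KeyError(f"{model} not in cost map")
    return 0.002

lookup_cost("gpt-4o")            # prints "looking up gpt-4o" and returns 0.002
lookup_cost("gpt-4o")            # served from the cache; nothing is printed
try:
    lookup_cost("unknown/foo")   # prints once; the ("error", KeyError) tuple is cached
except KeyError:
    pass
try:
    lookup_cost("unknown/foo")   # nothing is printed: the cached KeyError is re-raised
except KeyError:
    pass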


@@ -12,6 +12,7 @@ import time
import traceback
import uuid
from datetime import datetime as dt_object
+from functools import lru_cache
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union, cast

from pydantic import BaseModel
@@ -835,7 +836,7 @@ class Logging(LiteLLMLoggingBaseClass):
        except Exception as e:  # error calculating cost
            debug_info = StandardLoggingModelCostFailureDebugInformation(
                error_str=str(e),
-                traceback_str=traceback.format_exc(),
+                traceback_str=_get_traceback_str_for_error(str(e)),
                model=response_cost_calculator_kwargs["model"],
                cache_hit=response_cost_calculator_kwargs["cache_hit"],
                custom_llm_provider=response_cost_calculator_kwargs[
@@ -3320,3 +3321,11 @@ def modify_integration(integration_name, integration_params):
    if integration_name == "supabase":
        if "table_name" in integration_params:
            Supabase.supabase_table_name = integration_params["table_name"]
+
+
+@lru_cache(maxsize=16)
+def _get_traceback_str_for_error(error_str: str) -> str:
+    """
+    function wrapped with lru_cache to limit the number of times `traceback.format_exc()` is called
+    """
+    return traceback.format_exc()
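
The helper above memoizes the formatted traceback keyed on the error string, so repeated cost-calculation failures that raise the same message only pay for traceback.format_exc() once; later failures with an identical message reuse the first formatted traceback, which is the trade-off accepted for the perf win. A standalone sketch of that behavior, with a call counter added purely for illustration:

import traceback
from functools import lru_cache

format_calls = {"count": 0}

@lru_cache(maxsize=16)
def _get_traceback_str_for_error(error_str: str) -> str:
    # mirrors the helper above, with a counter added for the sketch
    format_calls["count"] += 1
    return traceback.format_exc()

for _ in range(3):
    try:
        raise KeyError("model not found in cost map")
    except KeyError as e:
        tb = _get_traceback_str_for_error(str(e))

assert format_calls["count"] == 1  # formatted once, then answered from the cache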


@@ -57,6 +57,7 @@ import litellm._service_logger  # for storing API inputs, outputs, and metadata
import litellm.litellm_core_utils
import litellm.litellm_core_utils.audio_utils.utils
import litellm.litellm_core_utils.json_validation_rule
+from litellm.caching._internal_lru_cache import lru_cache_wrapper
from litellm.caching.caching import DualCache
from litellm.caching.caching_handler import CachingHandlerResponse, LLMCachingHandler
from litellm.integrations.custom_logger import CustomLogger
@@ -4013,7 +4014,7 @@ def _get_max_position_embeddings(model_name: str) -> Optional[int]:
        return None


-@lru_cache(maxsize=16)
+@lru_cache_wrapper(maxsize=16)
def _cached_get_model_info_helper(
    model: str, custom_llm_provider: Optional[str]
) -> ModelInfoBase:
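
Swapping functools.lru_cache for lru_cache_wrapper on _cached_get_model_info_helper is what makes the "unable to look up model in model cost map" case cheap: plain lru_cache only stores calls that return, so a lookup that raises would be re-executed on every request, while the wrapper caches the exception and re-raises it. A rough sketch of the difference, assuming the new module import works as shown above and using hypothetical failing lookups:

from functools import lru_cache
from litellm.caching._internal_lru_cache import lru_cache_wrapper  # added in this commit

lookups = {"plain": 0, "wrapped": 0}

@lru_cache(maxsize=16)
def plain_lookup(model: str) -> dict:
    lookups["plain"] += 1
    raise ValueError(f"{model} not found in model cost map")

@lru_cache_wrapper(maxsize=16)
def wrapped_lookup(model: str) -> dict:
    lookups["wrapped"] += 1
    raise ValueError(f"{model} not found in model cost map")

for _ in range(5):
    for fn in (plain_lookup, wrapped_lookup):
        try:
            fn("huggingface/mistralai/Mistral-7B-Instruct-v0.3")
        except ValueError:
            pass

assert lookups["plain"] == 5    # lru_cache does not memoize calls that raise
assert lookups["wrapped"] == 1  # the exception itself is cached and re-raised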