(litellm SDK perf improvements) - handle cases when unable to look up model in model cost map (#7750)

* use lru cache wrapper

* use lru_cache_wrapper for _cached_get_model_info_helper

* fix _get_traceback_str_for_error

* huggingface/mistralai/Mistral-7B-Instruct-v0.3
This commit is contained in:
Ishaan Jaff 2025-01-13 19:58:46 -08:00 committed by GitHub
parent c8ac61f117
commit d88f01d518
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 42 additions and 2 deletions

View file

@ -0,0 +1,30 @@
from functools import lru_cache
from typing import Callable, Optional, TypeVar
T = TypeVar("T")


def lru_cache_wrapper(
    maxsize: Optional[int] = None,
) -> Callable[[Callable[..., T]], Callable[..., T]]:
    """
    Wrapper for lru_cache that caches success and exceptions.

    A plain `lru_cache` does not remember a call that raised, so a failing
    lookup is recomputed on every call. This decorator stores the raised
    exception alongside normal results and re-raises it on later calls with
    the same arguments.

    Args:
        maxsize: forwarded unchanged to `functools.lru_cache`.

    Returns:
        A decorator. The decorated function keeps the metadata of the
        original (via `functools.wraps`) and exposes `cache_info()` and
        `cache_clear()` like a function decorated with `lru_cache`.

    Raises:
        Whatever `Exception` the wrapped function raised for those arguments
        (the same exception instance is re-raised on cache hits).
    """
    # Imported locally so this block does not depend on the module-level
    # import list, which only brings in `lru_cache`.
    from functools import wraps

    def decorator(f: Callable[..., T]) -> Callable[..., T]:
        @lru_cache(maxsize=maxsize)
        def wrapper(*args, **kwargs):
            try:
                return ("success", f(*args, **kwargs), None)
            except Exception as e:
                # Keep the traceback as it was at capture time; it is restored
                # on every re-raise below.
                return ("error", e, e.__traceback__)

        @wraps(f)
        def wrapped(*args, **kwargs):
            status, value, captured_tb = wrapper(*args, **kwargs)
            if status == "error":
                # Re-raising the same instance appends frames to its
                # traceback each time; resetting to the captured traceback
                # keeps it from growing on every cache hit.
                raise value.with_traceback(captured_tb)
            return value

        # Expose the cache controls callers expect from an lru_cache'd function.
        wrapped.cache_info = wrapper.cache_info  # type: ignore[attr-defined]
        wrapped.cache_clear = wrapper.cache_clear  # type: ignore[attr-defined]
        return wrapped

    return decorator

View file

@ -12,6 +12,7 @@ import time
import traceback import traceback
import uuid import uuid
from datetime import datetime as dt_object from datetime import datetime as dt_object
from functools import lru_cache
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union, cast from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union, cast
from pydantic import BaseModel from pydantic import BaseModel
@ -835,7 +836,7 @@ class Logging(LiteLLMLoggingBaseClass):
except Exception as e: # error calculating cost except Exception as e: # error calculating cost
debug_info = StandardLoggingModelCostFailureDebugInformation( debug_info = StandardLoggingModelCostFailureDebugInformation(
error_str=str(e), error_str=str(e),
traceback_str=traceback.format_exc(), traceback_str=_get_traceback_str_for_error(str(e)),
model=response_cost_calculator_kwargs["model"], model=response_cost_calculator_kwargs["model"],
cache_hit=response_cost_calculator_kwargs["cache_hit"], cache_hit=response_cost_calculator_kwargs["cache_hit"],
custom_llm_provider=response_cost_calculator_kwargs[ custom_llm_provider=response_cost_calculator_kwargs[
@ -3320,3 +3321,11 @@ def modify_integration(integration_name, integration_params):
if integration_name == "supabase": if integration_name == "supabase":
if "table_name" in integration_params: if "table_name" in integration_params:
Supabase.supabase_table_name = integration_params["table_name"] Supabase.supabase_table_name = integration_params["table_name"]
@lru_cache(maxsize=16)
def _get_traceback_str_for_error(error_str: str) -> str:
"""
function wrapped with lru_cache to limit the number of times `traceback.format_exc()` is called
"""
return traceback.format_exc()

View file

@ -57,6 +57,7 @@ import litellm._service_logger # for storing API inputs, outputs, and metadata
import litellm.litellm_core_utils import litellm.litellm_core_utils
import litellm.litellm_core_utils.audio_utils.utils import litellm.litellm_core_utils.audio_utils.utils
import litellm.litellm_core_utils.json_validation_rule import litellm.litellm_core_utils.json_validation_rule
from litellm.caching._internal_lru_cache import lru_cache_wrapper
from litellm.caching.caching import DualCache from litellm.caching.caching import DualCache
from litellm.caching.caching_handler import CachingHandlerResponse, LLMCachingHandler from litellm.caching.caching_handler import CachingHandlerResponse, LLMCachingHandler
from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.custom_logger import CustomLogger
@ -4013,7 +4014,7 @@ def _get_max_position_embeddings(model_name: str) -> Optional[int]:
return None return None
@lru_cache(maxsize=16) @lru_cache_wrapper(maxsize=16)
def _cached_get_model_info_helper( def _cached_get_model_info_helper(
model: str, custom_llm_provider: Optional[str] model: str, custom_llm_provider: Optional[str]
) -> ModelInfoBase: ) -> ModelInfoBase: