LiteLLM Minor Fixes & Improvements (11/12/2024) (#6705)

* fix(caching): convert arg to equivalent kwargs in llm caching handler

prevents unexpected errors when callers pass positional arguments

* fix(caching_handler.py): don't pass args to caching

* fix(caching): remove all *args from caching.py

* fix(caching): consistent function signatures + abc method

* test(caching_unit_tests.py): add unit tests for llm caching

ensures coverage for common caching scenarios across different implementations

* refactor(litellm_logging.py): move to using cache key from hidden params instead of regenerating one

* fix(router.py): drop redis password requirement

* fix(proxy_server.py): fix faulty slack alerting check

* fix(langfuse.py): avoid copying functions/thread lock objects in metadata

fixes a metadata copy error when a parent OTEL span is present in the metadata

* test: update test
Authored by Krish Dholakia on 2024-11-12 22:50:51 +05:30; committed by GitHub
parent d39fd60801
commit 9160d80fa5
23 changed files with 525 additions and 204 deletions

View file

@@ -8,6 +8,7 @@ Has 4 methods:
 - async_get_cache
 """

+from abc import ABC, abstractmethod
 from typing import TYPE_CHECKING, Any, Optional

 if TYPE_CHECKING:
@@ -18,7 +19,7 @@ else:
     Span = Any


-class BaseCache:
+class BaseCache(ABC):
     def __init__(self, default_ttl: int = 60):
         self.default_ttl = default_ttl
@@ -37,6 +38,10 @@ class BaseCache:
     async def async_set_cache(self, key, value, **kwargs):
         raise NotImplementedError

+    @abstractmethod
+    async def async_set_cache_pipeline(self, cache_list, **kwargs):
+        pass
+
     def get_cache(self, key, **kwargs):
         raise NotImplementedError
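
Since `BaseCache` is now an ABC with `async_set_cache_pipeline` marked abstract, any new backend has to supply that method. A minimal sketch of a conforming subclass (the class name, in-memory dict, and import path are illustrative assumptions, not part of this commit):

from typing import Any, Dict

from litellm.caching.base_cache import BaseCache  # import path assumed from this diff


class InMemoryListCache(BaseCache):
    """Hypothetical backend showing the contract enforced by the ABC."""

    def __init__(self, default_ttl: int = 60):
        super().__init__(default_ttl=default_ttl)
        self.store: Dict[str, Any] = {}

    def set_cache(self, key, value, **kwargs):
        self.store[key] = value

    async def async_set_cache(self, key, value, **kwargs):
        self.set_cache(key, value, **kwargs)

    async def async_set_cache_pipeline(self, cache_list, **kwargs):
        # required override now that the method is @abstractmethod
        for key, value in cache_list:
            await self.async_set_cache(key, value, **kwargs)

    def get_cache(self, key, **kwargs):
        return self.store.get(key)

    async def async_get_cache(self, key, **kwargs):
        return self.get_cache(key, **kwargs)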

View file

@@ -233,19 +233,18 @@ class Cache:
         if self.namespace is not None and isinstance(self.cache, RedisCache):
             self.cache.namespace = self.namespace

-    def get_cache_key(self, *args, **kwargs) -> str:
+    def get_cache_key(self, **kwargs) -> str:
         """
         Get the cache key for the given arguments.

         Args:
-            *args: args to litellm.completion() or embedding()
             **kwargs: kwargs to litellm.completion() or embedding()

         Returns:
             str: The cache key generated from the arguments, or None if no cache key could be generated.
         """
         cache_key = ""
-        verbose_logger.debug("\nGetting Cache key. Kwargs: %s", kwargs)
+        # verbose_logger.debug("\nGetting Cache key. Kwargs: %s", kwargs)

         preset_cache_key = self._get_preset_cache_key_from_kwargs(**kwargs)
         if preset_cache_key is not None:
@@ -521,7 +520,7 @@ class Cache:
                 return cached_response
         return cached_result

-    def get_cache(self, *args, **kwargs):
+    def get_cache(self, **kwargs):
         """
         Retrieves the cached result for the given arguments.

@@ -533,13 +532,13 @@ class Cache:
             The cached result if it exists, otherwise None.
         """
         try:  # never block execution
-            if self.should_use_cache(*args, **kwargs) is not True:
+            if self.should_use_cache(**kwargs) is not True:
                 return
             messages = kwargs.get("messages", [])
             if "cache_key" in kwargs:
                 cache_key = kwargs["cache_key"]
             else:
-                cache_key = self.get_cache_key(*args, **kwargs)
+                cache_key = self.get_cache_key(**kwargs)
             if cache_key is not None:
                 cache_control_args = kwargs.get("cache", {})
                 max_age = cache_control_args.get(
@@ -553,29 +552,28 @@ class Cache:
             print_verbose(f"An exception occurred: {traceback.format_exc()}")
             return None

-    async def async_get_cache(self, *args, **kwargs):
+    async def async_get_cache(self, **kwargs):
         """
         Async get cache implementation.

         Used for embedding calls in async wrapper
         """
         try:  # never block execution
-            if self.should_use_cache(*args, **kwargs) is not True:
+            if self.should_use_cache(**kwargs) is not True:
                 return
             kwargs.get("messages", [])
             if "cache_key" in kwargs:
                 cache_key = kwargs["cache_key"]
             else:
-                cache_key = self.get_cache_key(*args, **kwargs)
+                cache_key = self.get_cache_key(**kwargs)
             if cache_key is not None:
                 cache_control_args = kwargs.get("cache", {})
                 max_age = cache_control_args.get(
                     "s-max-age", cache_control_args.get("s-maxage", float("inf"))
                 )
-                cached_result = await self.cache.async_get_cache(
-                    cache_key, *args, **kwargs
-                )
+                cached_result = await self.cache.async_get_cache(cache_key, **kwargs)
                 return self._get_cache_logic(
                     cached_result=cached_result, max_age=max_age
                 )
@@ -583,7 +581,7 @@ class Cache:
             print_verbose(f"An exception occurred: {traceback.format_exc()}")
             return None

-    def _add_cache_logic(self, result, *args, **kwargs):
+    def _add_cache_logic(self, result, **kwargs):
         """
         Common implementation across sync + async add_cache functions
         """
@@ -591,7 +589,7 @@ class Cache:
             if "cache_key" in kwargs:
                 cache_key = kwargs["cache_key"]
             else:
-                cache_key = self.get_cache_key(*args, **kwargs)
+                cache_key = self.get_cache_key(**kwargs)
             if cache_key is not None:
                 if isinstance(result, BaseModel):
                     result = result.model_dump_json()
@@ -613,7 +611,7 @@ class Cache:
         except Exception as e:
             raise e

-    def add_cache(self, result, *args, **kwargs):
+    def add_cache(self, result, **kwargs):
         """
         Adds a result to the cache.

@@ -625,41 +623,42 @@ class Cache:
             None
         """
         try:
-            if self.should_use_cache(*args, **kwargs) is not True:
+            if self.should_use_cache(**kwargs) is not True:
                 return
             cache_key, cached_data, kwargs = self._add_cache_logic(
-                result=result, *args, **kwargs
+                result=result, **kwargs
             )
             self.cache.set_cache(cache_key, cached_data, **kwargs)
         except Exception as e:
             verbose_logger.exception(f"LiteLLM Cache: Excepton add_cache: {str(e)}")

-    async def async_add_cache(self, result, *args, **kwargs):
+    async def async_add_cache(self, result, **kwargs):
         """
         Async implementation of add_cache
         """
         try:
-            if self.should_use_cache(*args, **kwargs) is not True:
+            if self.should_use_cache(**kwargs) is not True:
                 return
             if self.type == "redis" and self.redis_flush_size is not None:
                 # high traffic - fill in results in memory and then flush
-                await self.batch_cache_write(result, *args, **kwargs)
+                await self.batch_cache_write(result, **kwargs)
             else:
                 cache_key, cached_data, kwargs = self._add_cache_logic(
-                    result=result, *args, **kwargs
+                    result=result, **kwargs
                 )
                 await self.cache.async_set_cache(cache_key, cached_data, **kwargs)
         except Exception as e:
             verbose_logger.exception(f"LiteLLM Cache: Excepton add_cache: {str(e)}")

-    async def async_add_cache_pipeline(self, result, *args, **kwargs):
+    async def async_add_cache_pipeline(self, result, **kwargs):
         """
         Async implementation of add_cache for Embedding calls

         Does a bulk write, to prevent using too many clients
         """
         try:
-            if self.should_use_cache(*args, **kwargs) is not True:
+            if self.should_use_cache(**kwargs) is not True:
                 return

             # set default ttl if not set
@@ -668,29 +667,27 @@ class Cache:
             cache_list = []
             for idx, i in enumerate(kwargs["input"]):
-                preset_cache_key = self.get_cache_key(*args, **{**kwargs, "input": i})
+                preset_cache_key = self.get_cache_key(**{**kwargs, "input": i})
                 kwargs["cache_key"] = preset_cache_key
                 embedding_response = result.data[idx]
                 cache_key, cached_data, kwargs = self._add_cache_logic(
                     result=embedding_response,
-                    *args,
                     **kwargs,
                 )
                 cache_list.append((cache_key, cached_data))
-            async_set_cache_pipeline = getattr(
-                self.cache, "async_set_cache_pipeline", None
-            )
-            if async_set_cache_pipeline:
-                await async_set_cache_pipeline(cache_list=cache_list, **kwargs)
-            else:
-                tasks = []
-                for val in cache_list:
-                    tasks.append(self.cache.async_set_cache(val[0], val[1], **kwargs))
-                await asyncio.gather(*tasks)
+            await self.cache.async_set_cache_pipeline(cache_list=cache_list, **kwargs)
+            # if async_set_cache_pipeline:
+            #     await async_set_cache_pipeline(cache_list=cache_list, **kwargs)
+            # else:
+            #     tasks = []
+            #     for val in cache_list:
+            #         tasks.append(self.cache.async_set_cache(val[0], val[1], **kwargs))
+            #     await asyncio.gather(*tasks)
         except Exception as e:
             verbose_logger.exception(f"LiteLLM Cache: Excepton add_cache: {str(e)}")

-    def should_use_cache(self, *args, **kwargs):
+    def should_use_cache(self, **kwargs):
         """
         Returns true if we should use the cache for LLM API calls

@@ -708,10 +705,8 @@ class Cache:
                 return True
         return False

-    async def batch_cache_write(self, result, *args, **kwargs):
-        cache_key, cached_data, kwargs = self._add_cache_logic(
-            result=result, *args, **kwargs
-        )
+    async def batch_cache_write(self, result, **kwargs):
+        cache_key, cached_data, kwargs = self._add_cache_logic(result=result, **kwargs)
         await self.cache.batch_cache_write(cache_key, cached_data, **kwargs)

     async def ping(self):
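
After this refactor the `Cache` helpers are keyword-only; there is no `*args` path into `get_cache_key`, `get_cache`, or `add_cache`. A hedged usage sketch (model and message values are placeholders; assumes the default local cache):

import litellm
from litellm.caching import Cache

litellm.cache = Cache()  # default local (in-memory) cache

call_kwargs = {
    "model": "gpt-3.5-turbo",
    "messages": [{"role": "user", "content": "hi"}],
}

# the key is derived purely from kwargs now; there is no *args variant
key = litellm.cache.get_cache_key(**call_kwargs)
cached = litellm.cache.get_cache(**call_kwargs)  # None until something is stored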

View file

@@ -137,7 +137,7 @@ class LLMCachingHandler:
         if litellm.cache is not None and self._is_call_type_supported_by_cache(
             original_function=original_function
         ):
-            print_verbose("Checking Cache")
+            verbose_logger.debug("Checking Cache")
             cached_result = await self._retrieve_from_cache(
                 call_type=call_type,
                 kwargs=kwargs,
@@ -145,7 +145,7 @@
             )

             if cached_result is not None and not isinstance(cached_result, list):
-                print_verbose("Cache Hit!")
+                verbose_logger.debug("Cache Hit!")
                 cache_hit = True
                 end_time = datetime.datetime.now()
                 model, _, _, _ = litellm.get_llm_provider(
@@ -215,6 +215,7 @@ class LLMCachingHandler:
                     final_embedding_cached_response=final_embedding_cached_response,
                     embedding_all_elements_cache_hit=embedding_all_elements_cache_hit,
                 )
+        verbose_logger.debug(f"CACHE RESULT: {cached_result}")
         return CachingHandlerResponse(
             cached_result=cached_result,
             final_embedding_cached_response=final_embedding_cached_response,
@@ -233,12 +234,19 @@ class LLMCachingHandler:
         from litellm.utils import CustomStreamWrapper

         args = args or ()
+        new_kwargs = kwargs.copy()
+        new_kwargs.update(
+            convert_args_to_kwargs(
+                self.original_function,
+                args,
+            )
+        )
         cached_result: Optional[Any] = None
         if litellm.cache is not None and self._is_call_type_supported_by_cache(
             original_function=original_function
         ):
             print_verbose("Checking Cache")
-            cached_result = litellm.cache.get_cache(*args, **kwargs)
+            cached_result = litellm.cache.get_cache(**new_kwargs)
             if cached_result is not None:
                 if "detail" in cached_result:
                     # implies an error occurred
@@ -475,14 +483,21 @@ class LLMCachingHandler:
         if litellm.cache is None:
             return None

+        new_kwargs = kwargs.copy()
+        new_kwargs.update(
+            convert_args_to_kwargs(
+                self.original_function,
+                args,
+            )
+        )
         cached_result: Optional[Any] = None
         if call_type == CallTypes.aembedding.value and isinstance(
-            kwargs["input"], list
+            new_kwargs["input"], list
         ):
             tasks = []
-            for idx, i in enumerate(kwargs["input"]):
+            for idx, i in enumerate(new_kwargs["input"]):
                 preset_cache_key = litellm.cache.get_cache_key(
-                    *args, **{**kwargs, "input": i}
+                    **{**new_kwargs, "input": i}
                 )
                 tasks.append(litellm.cache.async_get_cache(cache_key=preset_cache_key))
             cached_result = await asyncio.gather(*tasks)
@@ -493,9 +508,9 @@ class LLMCachingHandler:
                 cached_result = None
         else:
             if litellm.cache._supports_async() is True:
-                cached_result = await litellm.cache.async_get_cache(*args, **kwargs)
+                cached_result = await litellm.cache.async_get_cache(**new_kwargs)
             else:  # for s3 caching. [NOT RECOMMENDED IN PROD - this will slow down responses since boto3 is sync]
-                cached_result = litellm.cache.get_cache(*args, **kwargs)
+                cached_result = litellm.cache.get_cache(**new_kwargs)
         return cached_result

     def _convert_cached_result_to_model_response(
@@ -580,6 +595,7 @@ class LLMCachingHandler:
                 model_response_object=EmbeddingResponse(),
                 response_type="embedding",
             )
+
         elif (
             call_type == CallTypes.arerank.value or call_type == CallTypes.rerank.value
         ) and isinstance(cached_result, dict):
@@ -603,6 +619,13 @@ class LLMCachingHandler:
                 response_type="audio_transcription",
                 hidden_params=hidden_params,
             )
+
+        if (
+            hasattr(cached_result, "_hidden_params")
+            and cached_result._hidden_params is not None
+            and isinstance(cached_result._hidden_params, dict)
+        ):
+            cached_result._hidden_params["cache_hit"] = True
         return cached_result

     def _convert_cached_stream_response(
@@ -658,12 +681,19 @@
         Raises:
             None
         """
-        kwargs.update(convert_args_to_kwargs(result, original_function, kwargs, args))
+        new_kwargs = kwargs.copy()
+        new_kwargs.update(
+            convert_args_to_kwargs(
+                original_function,
+                args,
+            )
+        )
         if litellm.cache is None:
             return

         # [OPTIONAL] ADD TO CACHE
         if self._should_store_result_in_cache(
-            original_function=original_function, kwargs=kwargs
+            original_function=original_function, kwargs=new_kwargs
         ):
             if (
                 isinstance(result, litellm.ModelResponse)
@@ -673,29 +703,29 @@ class LLMCachingHandler:
             ):
                 if (
                     isinstance(result, EmbeddingResponse)
-                    and isinstance(kwargs["input"], list)
+                    and isinstance(new_kwargs["input"], list)
                     and litellm.cache is not None
                     and not isinstance(
                         litellm.cache.cache, S3Cache
                     )  # s3 doesn't support bulk writing. Exclude.
                 ):
                     asyncio.create_task(
-                        litellm.cache.async_add_cache_pipeline(result, **kwargs)
+                        litellm.cache.async_add_cache_pipeline(result, **new_kwargs)
                     )
                 elif isinstance(litellm.cache.cache, S3Cache):
                     threading.Thread(
                         target=litellm.cache.add_cache,
                         args=(result,),
-                        kwargs=kwargs,
+                        kwargs=new_kwargs,
                     ).start()
                 else:
                     asyncio.create_task(
                         litellm.cache.async_add_cache(
-                            result.model_dump_json(), **kwargs
+                            result.model_dump_json(), **new_kwargs
                         )
                     )
             else:
-                asyncio.create_task(litellm.cache.async_add_cache(result, **kwargs))
+                asyncio.create_task(litellm.cache.async_add_cache(result, **new_kwargs))

     def sync_set_cache(
         self,
@@ -706,16 +736,20 @@ class LLMCachingHandler:
         """
         Sync internal method to add the result to the cache
         """
-        kwargs.update(
-            convert_args_to_kwargs(result, self.original_function, kwargs, args)
+        new_kwargs = kwargs.copy()
+        new_kwargs.update(
+            convert_args_to_kwargs(
+                self.original_function,
+                args,
+            )
         )
         if litellm.cache is None:
             return

         if self._should_store_result_in_cache(
-            original_function=self.original_function, kwargs=kwargs
+            original_function=self.original_function, kwargs=new_kwargs
         ):
-            litellm.cache.add_cache(result, **kwargs)
+            litellm.cache.add_cache(result, **new_kwargs)

         return
@@ -865,9 +899,7 @@ class LLMCachingHandler:


 def convert_args_to_kwargs(
-    result: Any,
     original_function: Callable,
-    kwargs: Dict[str, Any],
     args: Optional[Tuple[Any, ...]] = None,
 ) -> Dict[str, Any]:
     # Get the signature of the original function
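
The handler now normalizes any positional arguments into kwargs (via `convert_args_to_kwargs`) before building cache keys, so lookups and writes see one canonical set of kwargs. A standalone sketch of the general signature-binding idea, not the exact body of litellm's helper:

import inspect
from typing import Any, Callable, Dict, Optional, Tuple


def args_to_kwargs_sketch(
    original_function: Callable,
    args: Optional[Tuple[Any, ...]] = None,
) -> Dict[str, Any]:
    # Map positional args onto parameter names taken from the function signature
    signature = inspect.signature(original_function)
    param_names = list(signature.parameters.keys())
    return dict(zip(param_names, args or ()))


def completion(model, messages=None, **kwargs):
    ...


# {'model': 'gpt-3.5-turbo'} - positional call data becomes cache-key kwargs
print(args_to_kwargs_sketch(completion, ("gpt-3.5-turbo",)))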

View file

@@ -24,7 +24,6 @@ class DiskCache(BaseCache):
         self.disk_cache = dc.Cache(disk_cache_dir)

     def set_cache(self, key, value, **kwargs):
-        print_verbose("DiskCache: set_cache")
         if "ttl" in kwargs:
             self.disk_cache.set(key, value, expire=kwargs["ttl"])
         else:
@@ -33,10 +32,10 @@ class DiskCache(BaseCache):
     async def async_set_cache(self, key, value, **kwargs):
         self.set_cache(key=key, value=value, **kwargs)

-    async def async_set_cache_pipeline(self, cache_list, ttl=None):
+    async def async_set_cache_pipeline(self, cache_list, **kwargs):
         for cache_key, cache_value in cache_list:
-            if ttl is not None:
-                self.set_cache(key=cache_key, value=cache_value, ttl=ttl)
+            if "ttl" in kwargs:
+                self.set_cache(key=cache_key, value=cache_value, ttl=kwargs["ttl"])
             else:
                 self.set_cache(key=cache_key, value=cache_value)

View file

@@ -314,7 +314,8 @@ class DualCache(BaseCache):
                 f"LiteLLM Cache: Excepton async add_cache: {str(e)}"
             )

-    async def async_batch_set_cache(
+    # async_batch_set_cache
+    async def async_set_cache_pipeline(
         self, cache_list: list, local_only: bool = False, **kwargs
     ):
         """
""" """

View file

@@ -9,6 +9,7 @@ Has 4 methods:
 """

 import ast
+import asyncio
 import json
 from typing import Any

@@ -422,3 +423,9 @@ class QdrantSemanticCache(BaseCache):

     async def _collection_info(self):
         return self.collection_info
+
+    async def async_set_cache_pipeline(self, cache_list, **kwargs):
+        tasks = []
+        for val in cache_list:
+            tasks.append(self.async_set_cache(val[0], val[1], **kwargs))
+        await asyncio.gather(*tasks)

View file

@@ -404,7 +404,7 @@ class RedisCache(BaseCache):
                     parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
                 )
             )
-            return results
+            return None
         except Exception as e:
             ## LOGGING ##
             end_time = time.time()

View file

@@ -9,6 +9,7 @@ Has 4 methods:
 """

 import ast
+import asyncio
 import json
 from typing import Any

@@ -331,3 +332,9 @@ class RedisSemanticCache(BaseCache):

     async def _index_info(self):
         return await self.index.ainfo()
+
+    async def async_set_cache_pipeline(self, cache_list, **kwargs):
+        tasks = []
+        for val in cache_list:
+            tasks.append(self.async_set_cache(val[0], val[1], **kwargs))
+        await asyncio.gather(*tasks)

View file

@@ -10,6 +10,7 @@ Has 4 methods:
 """

 import ast
+import asyncio
 import json
 from typing import Any, Optional

@@ -153,3 +154,9 @@ class S3Cache(BaseCache):

     async def disconnect(self):
         pass
+
+    async def async_set_cache_pipeline(self, cache_list, **kwargs):
+        tasks = []
+        for val in cache_list:
+            tasks.append(self.async_set_cache(val[0], val[1], **kwargs))
+        await asyncio.gather(*tasks)

View file

@@ -423,7 +423,7 @@ class SlackAlerting(CustomBatchLogger):
         latency_cache_keys = [(key, 0) for key in latency_keys]
         failed_request_cache_keys = [(key, 0) for key in failed_request_keys]
         combined_metrics_cache_keys = latency_cache_keys + failed_request_cache_keys
-        await self.internal_usage_cache.async_batch_set_cache(
+        await self.internal_usage_cache.async_set_cache_pipeline(
            cache_list=combined_metrics_cache_keys
        )

View file

@@ -3,8 +3,9 @@
 import copy
 import os
 import traceback
+import types
 from collections.abc import MutableMapping, MutableSequence, MutableSet
-from typing import TYPE_CHECKING, Any, Dict, Optional
+from typing import TYPE_CHECKING, Any, Dict, Optional, cast

 from packaging.version import Version
 from pydantic import BaseModel
@@ -355,17 +356,28 @@ class LangFuseLogger:
             )
         )

-    def _prepare_metadata(self, metadata) -> Any:
+    def is_base_type(self, value: Any) -> bool:
+        # Check if the value is of a base type
+        base_types = (int, float, str, bool, list, dict, tuple)
+        return isinstance(value, base_types)
+
+    def _prepare_metadata(self, metadata: Optional[dict]) -> Any:
         try:
-            return copy.deepcopy(metadata)  # Avoid modifying the original metadata
-        except (TypeError, copy.Error) as e:
-            verbose_logger.warning(f"Langfuse Layer Error - {e}")
+            if metadata is None:
+                return None
+
+            # Filter out function types from the metadata
+            sanitized_metadata = {k: v for k, v in metadata.items() if not callable(v)}
+
+            return copy.deepcopy(sanitized_metadata)
+        except Exception as e:
+            verbose_logger.debug(f"Langfuse Layer Error - {e}, metadata: {metadata}")

         new_metadata: Dict[str, Any] = {}
+
         # if metadata is not a MutableMapping, return an empty dict since we can't call items() on it
         if not isinstance(metadata, MutableMapping):
-            verbose_logger.warning(
+            verbose_logger.debug(
                 "Langfuse Layer Logging - metadata is not a MutableMapping, returning empty dict"
             )
             return new_metadata
@@ -373,25 +385,40 @@ class LangFuseLogger:
         for key, value in metadata.items():
             try:
                 if isinstance(value, MutableMapping):
-                    new_metadata[key] = self._prepare_metadata(value)
-                elif isinstance(value, (MutableSequence, MutableSet)):
-                    new_metadata[key] = type(value)(
-                        *(
-                            (
-                                self._prepare_metadata(v)
-                                if isinstance(v, MutableMapping)
-                                else copy.deepcopy(v)
-                            )
-                            for v in value
-                        )
-                    )
+                    new_metadata[key] = self._prepare_metadata(cast(dict, value))
+                elif isinstance(value, MutableSequence):
+                    # For lists or other mutable sequences
+                    new_metadata[key] = list(
+                        (
+                            self._prepare_metadata(cast(dict, v))
+                            if isinstance(v, MutableMapping)
+                            else copy.deepcopy(v)
+                        )
+                        for v in value
+                    )
+                elif isinstance(value, MutableSet):
+                    # For sets specifically, create a new set by passing an iterable
+                    new_metadata[key] = set(
+                        (
+                            self._prepare_metadata(cast(dict, v))
+                            if isinstance(v, MutableMapping)
+                            else copy.deepcopy(v)
+                        )
+                        for v in value
+                    )
                 elif isinstance(value, BaseModel):
                     new_metadata[key] = value.model_dump()
+                elif self.is_base_type(value):
+                    new_metadata[key] = value
                 else:
-                    new_metadata[key] = copy.deepcopy(value)
+                    verbose_logger.debug(
+                        f"Langfuse Layer Error - Unsupported metadata type: {type(value)} for key: {key}"
+                    )
+                    continue
+
             except (TypeError, copy.Error):
-                verbose_logger.warning(
-                    f"Langfuse Layer Error - Couldn't copy metadata key: {key} - {traceback.format_exc()}"
+                verbose_logger.debug(
+                    f"Langfuse Layer Error - Couldn't copy metadata key: {key}, type of key: {type(key)}, type of value: {type(value)} - {traceback.format_exc()}"
                 )

         return new_metadata
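
The per-key fallback now keeps only plain base types and drops callables and other non-copyable objects instead of letting `deepcopy` raise. A small stdlib-only illustration of that rule (not litellm code):

import threading

base_types = (int, float, str, bool, list, dict, tuple)

values = {"run_id": "abc-123", "retries": 3, "lock": threading.Lock(), "callback": print}

# keys whose values survive the new per-key checks: base types stay, locks and functions are skipped
kept = {k: v for k, v in values.items() if isinstance(v, base_types)}
print(kept)  # {'run_id': 'abc-123', 'retries': 3}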

View file

@@ -2774,11 +2774,6 @@ def get_standard_logging_object_payload(
             metadata=metadata
         )
     )

-    if litellm.cache is not None:
-        cache_key = litellm.cache.get_cache_key(**kwargs)
-    else:
-        cache_key = None
-
     saved_cache_cost: float = 0.0
     if cache_hit is True:
@@ -2820,7 +2815,7 @@ def get_standard_logging_object_payload(
         completionStartTime=completion_start_time_float,
         model=kwargs.get("model", "") or "",
         metadata=clean_metadata,
-        cache_key=cache_key,
+        cache_key=clean_hidden_params["cache_key"],
         response_cost=response_cost,
         total_tokens=usage.total_tokens,
         prompt_tokens=usage.prompt_tokens,

View file

@@ -1,12 +1,80 @@
 model_list:
   - model_name: "*"
     litellm_params:
-      model: "*"
+      model: claude-3-5-sonnet-20240620
+      api_key: os.environ/ANTHROPIC_API_KEY
+  - model_name: claude-3-5-sonnet-aihubmix
+    litellm_params:
+      model: openai/claude-3-5-sonnet-20240620
+      input_cost_per_token: 0.000003 # 3$/M
+      output_cost_per_token: 0.000015 # 15$/M
+      api_base: "https://exampleopenaiendpoint-production.up.railway.app"
+      api_key: my-fake-key
+  - model_name: fake-openai-endpoint-2
+    litellm_params:
+      model: openai/my-fake-model
+      api_key: my-fake-key
+      api_base: https://exampleopenaiendpoint-production.up.railway.app/
+      stream_timeout: 0.001
+      timeout: 1
+      rpm: 1
+  - model_name: fake-openai-endpoint
+    litellm_params:
+      model: openai/my-fake-model
+      api_key: my-fake-key
+      api_base: https://exampleopenaiendpoint-production.up.railway.app/
+  ## bedrock chat completions
+  - model_name: "*anthropic.claude*"
+    litellm_params:
+      model: bedrock/*anthropic.claude*
+      aws_access_key_id: os.environ/BEDROCK_AWS_ACCESS_KEY_ID
+      aws_secret_access_key: os.environ/BEDROCK_AWS_SECRET_ACCESS_KEY
+      aws_region_name: os.environ/AWS_REGION_NAME
+      guardrailConfig:
+        "guardrailIdentifier": "h4dsqwhp6j66"
+        "guardrailVersion": "2"
+        "trace": "enabled"
+
+  ## bedrock embeddings
+  - model_name: "*amazon.titan-embed-*"
+    litellm_params:
+      model: bedrock/amazon.titan-embed-*
+      aws_access_key_id: os.environ/BEDROCK_AWS_ACCESS_KEY_ID
+      aws_secret_access_key: os.environ/BEDROCK_AWS_SECRET_ACCESS_KEY
+      aws_region_name: os.environ/AWS_REGION_NAME
+  - model_name: "*cohere.embed-*"
+    litellm_params:
+      model: bedrock/cohere.embed-*
+      aws_access_key_id: os.environ/BEDROCK_AWS_ACCESS_KEY_ID
+      aws_secret_access_key: os.environ/BEDROCK_AWS_SECRET_ACCESS_KEY
+      aws_region_name: os.environ/AWS_REGION_NAME
+  - model_name: "bedrock/*"
+    litellm_params:
+      model: bedrock/*
+      aws_access_key_id: os.environ/BEDROCK_AWS_ACCESS_KEY_ID
+      aws_secret_access_key: os.environ/BEDROCK_AWS_SECRET_ACCESS_KEY
+      aws_region_name: os.environ/AWS_REGION_NAME
+
+  - model_name: gpt-4
+    litellm_params:
+      model: azure/chatgpt-v-2
+      api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
+      api_version: "2023-05-15"
+      api_key: os.environ/AZURE_API_KEY # The `os.environ/` prefix tells litellm to read this from the env. See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault
+      rpm: 480
+      timeout: 300
+      stream_timeout: 60

 litellm_settings:
   fallbacks: [{ "claude-3-5-sonnet-20240620": ["claude-3-5-sonnet-aihubmix"] }]
   callbacks: ["otel", "prometheus"]
   default_redis_batch_cache_expiry: 10
+  # default_team_settings:
+  #   - team_id: "dbe2f686-a686-4896-864a-4c3924458709"
+  #     success_callback: ["langfuse"]
+  #     langfuse_public_key: os.environ/LANGFUSE_PUB_KEY_1 # Project 1
+  #     langfuse_secret: os.environ/LANGFUSE_PRIVATE_KEY_1 # Project 1

 # litellm_settings:
 #   cache: True

View file

@@ -1308,7 +1308,7 @@ async def update_cache( # noqa: PLR0915
         await _update_team_cache()

     asyncio.create_task(
-        user_api_key_cache.async_batch_set_cache(
+        user_api_key_cache.async_set_cache_pipeline(
             cache_list=values_to_update_in_cache,
             ttl=60,
             litellm_parent_otel_span=parent_otel_span,
@@ -2978,7 +2978,7 @@ class ProxyStartupEvent:
         if (
             proxy_logging_obj is not None
-            and proxy_logging_obj.slack_alerting_instance is not None
+            and proxy_logging_obj.slack_alerting_instance.alerting is not None
             and prisma_client is not None
         ):
             print("Alerting: Initializing Weekly/Monthly Spend Reports")  # noqa

View file

@@ -175,7 +175,7 @@ class InternalUsageCache:
         local_only: bool = False,
         **kwargs,
     ) -> None:
-        return await self.dual_cache.async_batch_set_cache(
+        return await self.dual_cache.async_set_cache_pipeline(
             cache_list=cache_list,
             local_only=local_only,
             litellm_parent_otel_span=litellm_parent_otel_span,

View file

@@ -339,11 +339,7 @@ class Router:
         cache_config: Dict[str, Any] = {}

         self.client_ttl = client_ttl
-        if redis_url is not None or (
-            redis_host is not None
-            and redis_port is not None
-            and redis_password is not None
-        ):
+        if redis_url is not None or (redis_host is not None and redis_port is not None):
             cache_type = "redis"

             if redis_url is not None:
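
With the password check removed, a Router pointed at an unauthenticated Redis (for example a local dev instance) can enable Redis-backed caching with just a host and port. A hedged usage sketch (model and endpoint values are placeholders):

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {"model": "gpt-3.5-turbo"},
        }
    ],
    # redis_password is no longer required to turn on redis-backed caching
    redis_host="localhost",
    redis_port=6379,
)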

View file

@@ -796,7 +796,7 @@ def client(original_function): # noqa: PLR0915
                 and kwargs.get("_arealtime", False) is not True
             ):  # allow users to control returning cached responses from the completion function
                 # checking cache
-                print_verbose("INSIDE CHECKING CACHE")
+                verbose_logger.debug("INSIDE CHECKING SYNC CACHE")
                 caching_handler_response: CachingHandlerResponse = (
                     _llm_caching_handler._sync_get_cache(
                         model=model or "",
@@ -808,6 +808,7 @@ def client(original_function): # noqa: PLR0915
                         args=args,
                     )
                 )
+
                 if caching_handler_response.cached_result is not None:
                     return caching_handler_response.cached_result

View file

@@ -0,0 +1,223 @@
from abc import ABC, abstractmethod

from litellm.caching import LiteLLMCacheType
import os
import sys
import time
import traceback
import uuid
from dotenv import load_dotenv
from test_rerank import assert_response_shape

load_dotenv()
import os

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import asyncio
import hashlib
import random

import pytest

import litellm
from litellm.caching import Cache
from litellm import completion, embedding


class LLMCachingUnitTests(ABC):

    @abstractmethod
    def get_cache_type(self) -> LiteLLMCacheType:
        pass

    @pytest.mark.parametrize("sync_mode", [True, False])
    @pytest.mark.asyncio
    async def test_cache_completion(self, sync_mode):
        litellm._turn_on_debug()

        random_number = random.randint(
            1, 100000
        )  # add a random number to ensure it's always adding / reading from cache
        messages = [
            {
                "role": "user",
                "content": f"write a one sentence poem about: {random_number}",
            }
        ]

        cache_type = self.get_cache_type()
        litellm.cache = Cache(
            type=cache_type,
        )

        if sync_mode:
            response1 = completion(
                "gpt-3.5-turbo",
                messages=messages,
                caching=True,
                max_tokens=20,
                mock_response="This number is so great!",
            )
        else:
            response1 = await litellm.acompletion(
                "gpt-3.5-turbo",
                messages=messages,
                caching=True,
                max_tokens=20,
                mock_response="This number is so great!",
            )
        # response2 is mocked to a different response from response1,
        # but the completion from the cache should be used instead of the mock
        # response since the input is the same as response1
        await asyncio.sleep(0.5)
        if sync_mode:
            response2 = completion(
                "gpt-3.5-turbo",
                messages=messages,
                caching=True,
                max_tokens=20,
                mock_response="This number is great!",
            )
        else:
            response2 = await litellm.acompletion(
                "gpt-3.5-turbo",
                messages=messages,
                caching=True,
                max_tokens=20,
                mock_response="This number is great!",
            )
        if (
            response1["choices"][0]["message"]["content"]
            != response2["choices"][0]["message"]["content"]
        ):  # 1 and 2 should be the same
            # 1&2 have the exact same input params. This MUST Be a CACHE HIT
            print(f"response1: {response1}")
            print(f"response2: {response2}")
            pytest.fail(
                f"Error occurred: response1 - {response1['choices'][0]['message']['content']} != response2 - {response2['choices'][0]['message']['content']}"
            )
        # Since the parameters are not the same as response1, response3 should actually
        # be the mock response
        if sync_mode:
            response3 = completion(
                "gpt-3.5-turbo",
                messages=messages,
                caching=True,
                temperature=0.5,
                mock_response="This number is awful!",
            )
        else:
            response3 = await litellm.acompletion(
                "gpt-3.5-turbo",
                messages=messages,
                caching=True,
                temperature=0.5,
                mock_response="This number is awful!",
            )

        print("\nresponse 1", response1)
        print("\nresponse 2", response2)
        print("\nresponse 3", response3)
        # print("\nresponse 4", response4)
        litellm.cache = None
        litellm.success_callback = []
        litellm._async_success_callback = []

        # 1 & 2 should be exactly the same
        # 1 & 3 should be different, since input params are diff
        if (
            response1["choices"][0]["message"]["content"]
            == response3["choices"][0]["message"]["content"]
        ):
            # if input params like max_tokens, temperature are diff it should NOT be a cache hit
            print(f"response1: {response1}")
            print(f"response3: {response3}")
            pytest.fail(
                f"Response 1 == response 3. Same model, diff params shoudl not cache Error"
                f" occurred:"
            )

        assert response1.id == response2.id
        assert response1.created == response2.created
        assert (
            response1.choices[0].message.content == response2.choices[0].message.content
        )

    @pytest.mark.parametrize("sync_mode", [True, False])
    @pytest.mark.asyncio
    async def test_disk_cache_embedding(self, sync_mode):
        litellm._turn_on_debug()

        random_number = random.randint(
            1, 100000
        )  # add a random number to ensure it's always adding / reading from cache
        input = [f"hello {random_number}"]
        litellm.cache = Cache(
            type="disk",
        )

        if sync_mode:
            response1 = embedding(
                "openai/text-embedding-ada-002",
                input=input,
                caching=True,
            )
        else:
            response1 = await litellm.aembedding(
                "openai/text-embedding-ada-002",
                input=input,
                caching=True,
            )
        # response2 is mocked to a different response from response1,
        # but the completion from the cache should be used instead of the mock
        # response since the input is the same as response1
        await asyncio.sleep(0.5)
        if sync_mode:
            response2 = embedding(
                "openai/text-embedding-ada-002",
                input=input,
                caching=True,
            )
        else:
            response2 = await litellm.aembedding(
                "openai/text-embedding-ada-002",
                input=input,
                caching=True,
            )

        if response2._hidden_params["cache_hit"] is not True:
            pytest.fail("Cache hit should be True")

        # Since the parameters are not the same as response1, response3 should actually
        # be the mock response
        if sync_mode:
            response3 = embedding(
                "openai/text-embedding-ada-002",
                input=input,
                user="charlie",
                caching=True,
            )
        else:
            response3 = await litellm.aembedding(
                "openai/text-embedding-ada-002",
                input=input,
                caching=True,
                user="charlie",
            )

        print("\nresponse 1", response1)
        print("\nresponse 2", response2)
        print("\nresponse 3", response3)
        # print("\nresponse 4", response4)
        litellm.cache = None
        litellm.success_callback = []
        litellm._async_success_callback = []

        # 1 & 2 should be exactly the same
        # 1 & 3 should be different, since input params are diff
        if response3._hidden_params.get("cache_hit") is True:
            pytest.fail("Cache hit should not be True")

View file

@@ -438,7 +438,7 @@ async def test_send_daily_reports_ignores_zero_values():
     slack_alerting.internal_usage_cache.async_batch_get_cache = AsyncMock(
         return_value=[None, 0, 10, 0, 0, None]
     )
-    slack_alerting.internal_usage_cache.async_batch_set_cache = AsyncMock()
+    slack_alerting.internal_usage_cache.async_set_cache_pipeline = AsyncMock()

     router.get_model_info.side_effect = lambda x: {"litellm_params": {"model": x}}

View file

@@ -1103,81 +1103,6 @@ async def test_redis_cache_acompletion_stream_bedrock():
         raise e


-def test_disk_cache_completion():
-    litellm.set_verbose = False
-
-    random_number = random.randint(
-        1, 100000
-    )  # add a random number to ensure it's always adding / reading from cache
-    messages = [
-        {"role": "user", "content": f"write a one sentence poem about: {random_number}"}
-    ]
-    litellm.cache = Cache(
-        type="disk",
-    )
-
-    response1 = completion(
-        model="gpt-3.5-turbo",
-        messages=messages,
-        caching=True,
-        max_tokens=20,
-        mock_response="This number is so great!",
-    )
-    # response2 is mocked to a different response from response1,
-    # but the completion from the cache should be used instead of the mock
-    # response since the input is the same as response1
-    response2 = completion(
-        model="gpt-3.5-turbo",
-        messages=messages,
-        caching=True,
-        max_tokens=20,
-        mock_response="This number is awful!",
-    )
-    # Since the parameters are not the same as response1, response3 should actually
-    # be the mock response
-    response3 = completion(
-        model="gpt-3.5-turbo",
-        messages=messages,
-        caching=True,
-        temperature=0.5,
-        mock_response="This number is awful!",
-    )
-
-    print("\nresponse 1", response1)
-    print("\nresponse 2", response2)
-    print("\nresponse 3", response3)
-    # print("\nresponse 4", response4)
-    litellm.cache = None
-    litellm.success_callback = []
-    litellm._async_success_callback = []
-
-    # 1 & 2 should be exactly the same
-    # 1 & 3 should be different, since input params are diff
-    if (
-        response1["choices"][0]["message"]["content"]
-        != response2["choices"][0]["message"]["content"]
-    ):  # 1 and 2 should be the same
-        # 1&2 have the exact same input params. This MUST Be a CACHE HIT
-        print(f"response1: {response1}")
-        print(f"response2: {response2}")
-        pytest.fail(f"Error occurred:")
-    if (
-        response1["choices"][0]["message"]["content"]
-        == response3["choices"][0]["message"]["content"]
-    ):
-        # if input params like max_tokens, temperature are diff it should NOT be a cache hit
-        print(f"response1: {response1}")
-        print(f"response3: {response3}")
-        pytest.fail(
-            f"Response 1 == response 3. Same model, diff params shoudl not cache Error"
-            f" occurred:"
-        )
-
-    assert response1.id == response2.id
-    assert response1.created == response2.created
-    assert response1.choices[0].message.content == response2.choices[0].message.content
-
-
 # @pytest.mark.skip(reason="AWS Suspended Account")
 @pytest.mark.parametrize("sync_mode", [True, False])
 @pytest.mark.asyncio

View file

@@ -0,0 +1,11 @@
from cache_unit_tests import LLMCachingUnitTests
from litellm.caching import LiteLLMCacheType


class TestDiskCacheUnitTests(LLMCachingUnitTests):
    def get_cache_type(self) -> LiteLLMCacheType:
        return LiteLLMCacheType.DISK


# if __name__ == "__main__":
#     pytest.main([__file__, "-v", "-s"])
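
Other cache backends can reuse the same shared scenarios by subclassing the suite. A hypothetical example for the Redis cache type (test placement and the exact enum member are assumptions based on this diff, not part of the commit):

from cache_unit_tests import LLMCachingUnitTests
from litellm.caching import LiteLLMCacheType


class TestRedisCacheUnitTests(LLMCachingUnitTests):
    # every concrete subclass only has to say which backend to exercise
    def get_cache_type(self) -> LiteLLMCacheType:
        return LiteLLMCacheType.REDIS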

View file

@@ -146,7 +146,7 @@ async def test_dual_cache_batch_operations(is_async):
     # Set values
     if is_async:
-        await dual_cache.async_batch_set_cache(cache_list)
+        await dual_cache.async_set_cache_pipeline(cache_list)
     else:
         for key, value in cache_list:
             dual_cache.set_cache(key, value)

View file

@@ -212,26 +212,48 @@ def test_get_langfuse_logger_for_request_with_cached_logger():
         assert result == cached_logger
         mock_cache.get_cache.assert_called_once()

-@pytest.mark.parametrize("metadata", [
-    {'a': 1, 'b': 2, 'c': 3},
-    {'a': {'nested_a': 1}, 'b': {'nested_b': 2}},
-    {'a': [1, 2, 3], 'b': {4, 5, 6}},
-    {'a': (1, 2), 'b': frozenset([3, 4]), 'c': {'d': [5, 6]}},
-    {'lock': threading.Lock()},
-    {'func': lambda x: x + 1},
-    {
-        'int': 42,
-        'str': 'hello',
-        'list': [1, 2, 3],
-        'set': {4, 5},
-        'dict': {'nested': 'value'},
-        'non_copyable': threading.Lock(),
-        'function': print
-    },
-    ['list', 'not', 'a', 'dict'],
-    {'timestamp': datetime.now()},
-    {},
-    None,
-])
-def test_langfuse_logger_prepare_metadata(metadata):
-    global_langfuse_logger._prepare_metadata(metadata)
+
+@pytest.mark.parametrize(
+    "metadata, expected_metadata",
+    [
+        ({"a": 1, "b": 2, "c": 3}, {"a": 1, "b": 2, "c": 3}),
+        (
+            {"a": {"nested_a": 1}, "b": {"nested_b": 2}},
+            {"a": {"nested_a": 1}, "b": {"nested_b": 2}},
+        ),
+        ({"a": [1, 2, 3], "b": {4, 5, 6}}, {"a": [1, 2, 3], "b": {4, 5, 6}}),
+        (
+            {"a": (1, 2), "b": frozenset([3, 4]), "c": {"d": [5, 6]}},
+            {"a": (1, 2), "b": frozenset([3, 4]), "c": {"d": [5, 6]}},
+        ),
+        ({"lock": threading.Lock()}, {}),
+        ({"func": lambda x: x + 1}, {}),
+        (
+            {
+                "int": 42,
+                "str": "hello",
+                "list": [1, 2, 3],
+                "set": {4, 5},
+                "dict": {"nested": "value"},
+                "non_copyable": threading.Lock(),
+                "function": print,
+            },
+            {
+                "int": 42,
+                "str": "hello",
+                "list": [1, 2, 3],
+                "set": {4, 5},
+                "dict": {"nested": "value"},
+            },
+        ),
+        (
+            {"list": ["list", "not", "a", "dict"]},
+            {"list": ["list", "not", "a", "dict"]},
+        ),
+        ({}, {}),
+        (None, None),
+    ],
+)
+def test_langfuse_logger_prepare_metadata(metadata, expected_metadata):
+    result = global_langfuse_logger._prepare_metadata(metadata)
+    assert result == expected_metadata