forked from phoenix/litellm-mirror
LiteLLM Minor Fixes & Improvements (11/12/2024) (#6705)
* fix(caching): convert args to equivalent kwargs in the LLM caching handler to prevent unexpected errors
* fix(caching_handler.py): don't pass args to caching
* fix(caching): remove all *args from caching.py
* fix(caching): consistent function signatures + abc method
* test(caching_unit_tests.py): add unit tests for LLM caching, ensuring coverage for common caching scenarios across different implementations
* refactor(litellm_logging.py): use the cache key from hidden params instead of regenerating one
* fix(router.py): drop the Redis password requirement
* fix(proxy_server.py): fix faulty Slack alerting check
* fix(langfuse.py): avoid copying functions/thread-lock objects in metadata; fixes the metadata copy error when a parent OTEL span is in the metadata
* test: update test
This commit is contained in: parent d39fd60801, commit 9160d80fa5
23 changed files with 525 additions and 204 deletions
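The headline change is that the caching layer no longer receives *args: positional arguments are folded into kwargs before any cache lookup or write. A minimal sketch of that idea, assuming a helper along these lines (args_to_kwargs_sketch is illustrative, not the exact convert_args_to_kwargs helper added in this PR):

import inspect
from typing import Any, Callable, Dict, Optional, Tuple

def args_to_kwargs_sketch(
    original_function: Callable,
    args: Optional[Tuple[Any, ...]] = None,
) -> Dict[str, Any]:
    # Bind positional args to the wrapped function's parameter names,
    # skipping "self" so bound methods behave like plain functions.
    signature = inspect.signature(original_function)
    param_names = [p for p in signature.parameters if p != "self"]
    return dict(zip(param_names, args or ()))

# e.g. mapping ("gpt-3.5-turbo",) onto a completion(model, ...) function
# yields {"model": "gpt-3.5-turbo"}, which is merged into kwargs before
# the cache key is computed.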
@@ -8,6 +8,7 @@ Has 4 methods:
- async_get_cache
"""

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Optional

if TYPE_CHECKING:

@@ -18,7 +19,7 @@ else:
    Span = Any


class BaseCache:
class BaseCache(ABC):
    def __init__(self, default_ttl: int = 60):
        self.default_ttl = default_ttl

@@ -37,6 +38,10 @@ class BaseCache:
    async def async_set_cache(self, key, value, **kwargs):
        raise NotImplementedError

    @abstractmethod
    async def async_set_cache_pipeline(self, cache_list, **kwargs):
        pass

    def get_cache(self, key, **kwargs):
        raise NotImplementedError
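With async_set_cache_pipeline now an abstract method on BaseCache, every backend has to provide it. A minimal sketch of a custom backend under that contract (InMemorySketchCache and its dict store are hypothetical, and the import path is an assumption):

from litellm.caching.base_cache import BaseCache  # module path is an assumption

class InMemorySketchCache(BaseCache):
    def __init__(self):
        super().__init__()
        self.store = {}

    def set_cache(self, key, value, **kwargs):
        self.store[key] = value

    def get_cache(self, key, **kwargs):
        return self.store.get(key)

    async def async_set_cache(self, key, value, **kwargs):
        self.set_cache(key, value, **kwargs)

    async def async_set_cache_pipeline(self, cache_list, **kwargs):
        # required now that the method is @abstractmethod on BaseCache
        for key, value in cache_list:
            self.set_cache(key, value, **kwargs)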
@@ -233,19 +233,18 @@ class Cache:
        if self.namespace is not None and isinstance(self.cache, RedisCache):
            self.cache.namespace = self.namespace

    def get_cache_key(self, *args, **kwargs) -> str:
    def get_cache_key(self, **kwargs) -> str:
        """
        Get the cache key for the given arguments.

        Args:
            *args: args to litellm.completion() or embedding()
            **kwargs: kwargs to litellm.completion() or embedding()

        Returns:
            str: The cache key generated from the arguments, or None if no cache key could be generated.
        """
        cache_key = ""
        verbose_logger.debug("\nGetting Cache key. Kwargs: %s", kwargs)
        # verbose_logger.debug("\nGetting Cache key. Kwargs: %s", kwargs)

        preset_cache_key = self._get_preset_cache_key_from_kwargs(**kwargs)
        if preset_cache_key is not None:

@@ -521,7 +520,7 @@ class Cache:
                return cached_response
            return cached_result

    def get_cache(self, *args, **kwargs):
    def get_cache(self, **kwargs):
        """
        Retrieves the cached result for the given arguments.

@@ -533,13 +532,13 @@ class Cache:
            The cached result if it exists, otherwise None.
        """
        try:  # never block execution
            if self.should_use_cache(*args, **kwargs) is not True:
            if self.should_use_cache(**kwargs) is not True:
                return
            messages = kwargs.get("messages", [])
            if "cache_key" in kwargs:
                cache_key = kwargs["cache_key"]
            else:
                cache_key = self.get_cache_key(*args, **kwargs)
                cache_key = self.get_cache_key(**kwargs)
            if cache_key is not None:
                cache_control_args = kwargs.get("cache", {})
                max_age = cache_control_args.get(

@@ -553,29 +552,28 @@ class Cache:
            print_verbose(f"An exception occurred: {traceback.format_exc()}")
            return None

    async def async_get_cache(self, *args, **kwargs):
    async def async_get_cache(self, **kwargs):
        """
        Async get cache implementation.

        Used for embedding calls in async wrapper
        """

        try:  # never block execution
            if self.should_use_cache(*args, **kwargs) is not True:
            if self.should_use_cache(**kwargs) is not True:
                return

            kwargs.get("messages", [])
            if "cache_key" in kwargs:
                cache_key = kwargs["cache_key"]
            else:
                cache_key = self.get_cache_key(*args, **kwargs)
                cache_key = self.get_cache_key(**kwargs)
            if cache_key is not None:
                cache_control_args = kwargs.get("cache", {})
                max_age = cache_control_args.get(
                    "s-max-age", cache_control_args.get("s-maxage", float("inf"))
                )
                cached_result = await self.cache.async_get_cache(
                    cache_key, *args, **kwargs
                )
                cached_result = await self.cache.async_get_cache(cache_key, **kwargs)
                return self._get_cache_logic(
                    cached_result=cached_result, max_age=max_age
                )

@@ -583,7 +581,7 @@ class Cache:
            print_verbose(f"An exception occurred: {traceback.format_exc()}")
            return None

    def _add_cache_logic(self, result, *args, **kwargs):
    def _add_cache_logic(self, result, **kwargs):
        """
        Common implementation across sync + async add_cache functions
        """

@@ -591,7 +589,7 @@ class Cache:
            if "cache_key" in kwargs:
                cache_key = kwargs["cache_key"]
            else:
                cache_key = self.get_cache_key(*args, **kwargs)
                cache_key = self.get_cache_key(**kwargs)
            if cache_key is not None:
                if isinstance(result, BaseModel):
                    result = result.model_dump_json()

@@ -613,7 +611,7 @@ class Cache:
        except Exception as e:
            raise e

    def add_cache(self, result, *args, **kwargs):
    def add_cache(self, result, **kwargs):
        """
        Adds a result to the cache.

@@ -625,41 +623,42 @@ class Cache:
            None
        """
        try:
            if self.should_use_cache(*args, **kwargs) is not True:
            if self.should_use_cache(**kwargs) is not True:
                return
            cache_key, cached_data, kwargs = self._add_cache_logic(
                result=result, *args, **kwargs
                result=result, **kwargs
            )
            self.cache.set_cache(cache_key, cached_data, **kwargs)
        except Exception as e:
            verbose_logger.exception(f"LiteLLM Cache: Excepton add_cache: {str(e)}")

    async def async_add_cache(self, result, *args, **kwargs):
    async def async_add_cache(self, result, **kwargs):
        """
        Async implementation of add_cache
        """
        try:
            if self.should_use_cache(*args, **kwargs) is not True:
            if self.should_use_cache(**kwargs) is not True:
                return
            if self.type == "redis" and self.redis_flush_size is not None:
                # high traffic - fill in results in memory and then flush
                await self.batch_cache_write(result, *args, **kwargs)
                await self.batch_cache_write(result, **kwargs)
            else:
                cache_key, cached_data, kwargs = self._add_cache_logic(
                    result=result, *args, **kwargs
                    result=result, **kwargs
                )

                await self.cache.async_set_cache(cache_key, cached_data, **kwargs)
        except Exception as e:
            verbose_logger.exception(f"LiteLLM Cache: Excepton add_cache: {str(e)}")

    async def async_add_cache_pipeline(self, result, *args, **kwargs):
    async def async_add_cache_pipeline(self, result, **kwargs):
        """
        Async implementation of add_cache for Embedding calls

        Does a bulk write, to prevent using too many clients
        """
        try:
            if self.should_use_cache(*args, **kwargs) is not True:
            if self.should_use_cache(**kwargs) is not True:
                return

            # set default ttl if not set

@@ -668,29 +667,27 @@ class Cache:

            cache_list = []
            for idx, i in enumerate(kwargs["input"]):
                preset_cache_key = self.get_cache_key(*args, **{**kwargs, "input": i})
                preset_cache_key = self.get_cache_key(**{**kwargs, "input": i})
                kwargs["cache_key"] = preset_cache_key
                embedding_response = result.data[idx]
                cache_key, cached_data, kwargs = self._add_cache_logic(
                    result=embedding_response,
                    *args,
                    **kwargs,
                )
                cache_list.append((cache_key, cached_data))
            async_set_cache_pipeline = getattr(
                self.cache, "async_set_cache_pipeline", None
            )
            if async_set_cache_pipeline:
                await async_set_cache_pipeline(cache_list=cache_list, **kwargs)
            else:
                tasks = []
                for val in cache_list:
                    tasks.append(self.cache.async_set_cache(val[0], val[1], **kwargs))
                await asyncio.gather(*tasks)

            await self.cache.async_set_cache_pipeline(cache_list=cache_list, **kwargs)
            # if async_set_cache_pipeline:
            # await async_set_cache_pipeline(cache_list=cache_list, **kwargs)
            # else:
            # tasks = []
            # for val in cache_list:
            # tasks.append(self.cache.async_set_cache(val[0], val[1], **kwargs))
            # await asyncio.gather(*tasks)
        except Exception as e:
            verbose_logger.exception(f"LiteLLM Cache: Excepton add_cache: {str(e)}")

    def should_use_cache(self, *args, **kwargs):
    def should_use_cache(self, **kwargs):
        """
        Returns true if we should use the cache for LLM API calls

@@ -708,10 +705,8 @@ class Cache:
            return True
        return False

    async def batch_cache_write(self, result, *args, **kwargs):
        cache_key, cached_data, kwargs = self._add_cache_logic(
            result=result, *args, **kwargs
        )
    async def batch_cache_write(self, result, **kwargs):
        cache_key, cached_data, kwargs = self._add_cache_logic(result=result, **kwargs)
        await self.cache.batch_cache_write(cache_key, cached_data, **kwargs)

    async def ping(self):
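Because the public Cache methods are now keyword-only about their call data, callers pass everything that used to ride in *args as named arguments. A hedged usage sketch (model and messages values are illustrative):

import litellm
from litellm.caching import Cache

litellm.cache = Cache()  # defaults to the in-memory backend

kwargs = {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "hi"}]}
key = litellm.cache.get_cache_key(**kwargs)            # no positional args accepted any more
cached = litellm.cache.get_cache(cache_key=key, **kwargs)  # "cache_key" in kwargs short-circuits key generation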
@@ -137,7 +137,7 @@ class LLMCachingHandler:
        if litellm.cache is not None and self._is_call_type_supported_by_cache(
            original_function=original_function
        ):
            print_verbose("Checking Cache")
            verbose_logger.debug("Checking Cache")
            cached_result = await self._retrieve_from_cache(
                call_type=call_type,
                kwargs=kwargs,

@@ -145,7 +145,7 @@ class LLMCachingHandler:
            )

            if cached_result is not None and not isinstance(cached_result, list):
                print_verbose("Cache Hit!")
                verbose_logger.debug("Cache Hit!")
                cache_hit = True
                end_time = datetime.datetime.now()
                model, _, _, _ = litellm.get_llm_provider(

@@ -215,6 +215,7 @@ class LLMCachingHandler:
                final_embedding_cached_response=final_embedding_cached_response,
                embedding_all_elements_cache_hit=embedding_all_elements_cache_hit,
            )
        verbose_logger.debug(f"CACHE RESULT: {cached_result}")
        return CachingHandlerResponse(
            cached_result=cached_result,
            final_embedding_cached_response=final_embedding_cached_response,

@@ -233,12 +234,19 @@ class LLMCachingHandler:
        from litellm.utils import CustomStreamWrapper

        args = args or ()
        new_kwargs = kwargs.copy()
        new_kwargs.update(
            convert_args_to_kwargs(
                self.original_function,
                args,
            )
        )
        cached_result: Optional[Any] = None
        if litellm.cache is not None and self._is_call_type_supported_by_cache(
            original_function=original_function
        ):
            print_verbose("Checking Cache")
            cached_result = litellm.cache.get_cache(*args, **kwargs)
            cached_result = litellm.cache.get_cache(**new_kwargs)
            if cached_result is not None:
                if "detail" in cached_result:
                    # implies an error occurred

@@ -475,14 +483,21 @@ class LLMCachingHandler:
        if litellm.cache is None:
            return None

        new_kwargs = kwargs.copy()
        new_kwargs.update(
            convert_args_to_kwargs(
                self.original_function,
                args,
            )
        )
        cached_result: Optional[Any] = None
        if call_type == CallTypes.aembedding.value and isinstance(
            kwargs["input"], list
            new_kwargs["input"], list
        ):
            tasks = []
            for idx, i in enumerate(kwargs["input"]):
            for idx, i in enumerate(new_kwargs["input"]):
                preset_cache_key = litellm.cache.get_cache_key(
                    *args, **{**kwargs, "input": i}
                    **{**new_kwargs, "input": i}
                )
                tasks.append(litellm.cache.async_get_cache(cache_key=preset_cache_key))
            cached_result = await asyncio.gather(*tasks)

@@ -493,9 +508,9 @@ class LLMCachingHandler:
                cached_result = None
        else:
            if litellm.cache._supports_async() is True:
                cached_result = await litellm.cache.async_get_cache(*args, **kwargs)
                cached_result = await litellm.cache.async_get_cache(**new_kwargs)
            else:  # for s3 caching. [NOT RECOMMENDED IN PROD - this will slow down responses since boto3 is sync]
                cached_result = litellm.cache.get_cache(*args, **kwargs)
                cached_result = litellm.cache.get_cache(**new_kwargs)
        return cached_result

    def _convert_cached_result_to_model_response(

@@ -580,6 +595,7 @@ class LLMCachingHandler:
                model_response_object=EmbeddingResponse(),
                response_type="embedding",
            )

        elif (
            call_type == CallTypes.arerank.value or call_type == CallTypes.rerank.value
        ) and isinstance(cached_result, dict):

@@ -603,6 +619,13 @@ class LLMCachingHandler:
                response_type="audio_transcription",
                hidden_params=hidden_params,
            )

        if (
            hasattr(cached_result, "_hidden_params")
            and cached_result._hidden_params is not None
            and isinstance(cached_result._hidden_params, dict)
        ):
            cached_result._hidden_params["cache_hit"] = True
        return cached_result

    def _convert_cached_stream_response(

@@ -658,12 +681,19 @@ class LLMCachingHandler:
        Raises:
            None
        """
        kwargs.update(convert_args_to_kwargs(result, original_function, kwargs, args))

        new_kwargs = kwargs.copy()
        new_kwargs.update(
            convert_args_to_kwargs(
                original_function,
                args,
            )
        )
        if litellm.cache is None:
            return
        # [OPTIONAL] ADD TO CACHE
        if self._should_store_result_in_cache(
            original_function=original_function, kwargs=kwargs
            original_function=original_function, kwargs=new_kwargs
        ):
            if (
                isinstance(result, litellm.ModelResponse)

@@ -673,29 +703,29 @@ class LLMCachingHandler:
            ):
                if (
                    isinstance(result, EmbeddingResponse)
                    and isinstance(kwargs["input"], list)
                    and isinstance(new_kwargs["input"], list)
                    and litellm.cache is not None
                    and not isinstance(
                        litellm.cache.cache, S3Cache
                    )  # s3 doesn't support bulk writing. Exclude.
                ):
                    asyncio.create_task(
                        litellm.cache.async_add_cache_pipeline(result, **kwargs)
                        litellm.cache.async_add_cache_pipeline(result, **new_kwargs)
                    )
                elif isinstance(litellm.cache.cache, S3Cache):
                    threading.Thread(
                        target=litellm.cache.add_cache,
                        args=(result,),
                        kwargs=kwargs,
                        kwargs=new_kwargs,
                    ).start()
                else:
                    asyncio.create_task(
                        litellm.cache.async_add_cache(
                            result.model_dump_json(), **kwargs
                            result.model_dump_json(), **new_kwargs
                        )
                    )
            else:
                asyncio.create_task(litellm.cache.async_add_cache(result, **kwargs))
                asyncio.create_task(litellm.cache.async_add_cache(result, **new_kwargs))

    def sync_set_cache(
        self,

@@ -706,16 +736,20 @@ class LLMCachingHandler:
        """
        Sync internal method to add the result to the cache
        """
        kwargs.update(
            convert_args_to_kwargs(result, self.original_function, kwargs, args)
        new_kwargs = kwargs.copy()
        new_kwargs.update(
            convert_args_to_kwargs(
                self.original_function,
                args,
            )
        )
        if litellm.cache is None:
            return

        if self._should_store_result_in_cache(
            original_function=self.original_function, kwargs=kwargs
            original_function=self.original_function, kwargs=new_kwargs
        ):
            litellm.cache.add_cache(result, **kwargs)
            litellm.cache.add_cache(result, **new_kwargs)

        return


@@ -865,9 +899,7 @@ class LLMCachingHandler:


def convert_args_to_kwargs(
    result: Any,
    original_function: Callable,
    kwargs: Dict[str, Any],
    args: Optional[Tuple[Any, ...]] = None,
) -> Dict[str, Any]:
    # Get the signature of the original function
@@ -24,7 +24,6 @@ class DiskCache(BaseCache):
        self.disk_cache = dc.Cache(disk_cache_dir)

    def set_cache(self, key, value, **kwargs):
        print_verbose("DiskCache: set_cache")
        if "ttl" in kwargs:
            self.disk_cache.set(key, value, expire=kwargs["ttl"])
        else:

@@ -33,10 +32,10 @@ class DiskCache(BaseCache):
    async def async_set_cache(self, key, value, **kwargs):
        self.set_cache(key=key, value=value, **kwargs)

    async def async_set_cache_pipeline(self, cache_list, ttl=None):
    async def async_set_cache_pipeline(self, cache_list, **kwargs):
        for cache_key, cache_value in cache_list:
            if ttl is not None:
                self.set_cache(key=cache_key, value=cache_value, ttl=ttl)
            if "ttl" in kwargs:
                self.set_cache(key=cache_key, value=cache_value, ttl=kwargs["ttl"])
            else:
                self.set_cache(key=cache_key, value=cache_value)
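The disk backend's pipeline now reads ttl out of **kwargs instead of a dedicated parameter, so batched writes are invoked the same way as every other backend. A small usage sketch (the import path and the key/value pairs are assumptions for illustration):

import asyncio
from litellm.caching.disk_cache import DiskCache  # module path is an assumption

disk_cache = DiskCache(disk_cache_dir=".litellm_cache")
asyncio.run(
    disk_cache.async_set_cache_pipeline(
        cache_list=[("key-1", "value-1"), ("key-2", "value-2")],
        ttl=120,  # forwarded per entry to set_cache via kwargs["ttl"]
    )
)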
@@ -314,7 +314,8 @@ class DualCache(BaseCache):
                f"LiteLLM Cache: Excepton async add_cache: {str(e)}"
            )

    async def async_batch_set_cache(
    # async_batch_set_cache
    async def async_set_cache_pipeline(
        self, cache_list: list, local_only: bool = False, **kwargs
    ):
        """
@@ -9,6 +9,7 @@ Has 4 methods:
"""

import ast
import asyncio
import json
from typing import Any

@@ -422,3 +423,9 @@ class QdrantSemanticCache(BaseCache):

    async def _collection_info(self):
        return self.collection_info

    async def async_set_cache_pipeline(self, cache_list, **kwargs):
        tasks = []
        for val in cache_list:
            tasks.append(self.async_set_cache(val[0], val[1], **kwargs))
        await asyncio.gather(*tasks)
@@ -404,7 +404,7 @@ class RedisCache(BaseCache):
                        parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
                    )
                )
            return results
            return None
        except Exception as e:
            ## LOGGING ##
            end_time = time.time()
@@ -9,6 +9,7 @@ Has 4 methods:
"""

import ast
import asyncio
import json
from typing import Any

@@ -331,3 +332,9 @@ class RedisSemanticCache(BaseCache):

    async def _index_info(self):
        return await self.index.ainfo()

    async def async_set_cache_pipeline(self, cache_list, **kwargs):
        tasks = []
        for val in cache_list:
            tasks.append(self.async_set_cache(val[0], val[1], **kwargs))
        await asyncio.gather(*tasks)
@@ -10,6 +10,7 @@ Has 4 methods:
"""

import ast
import asyncio
import json
from typing import Any, Optional

@@ -153,3 +154,9 @@ class S3Cache(BaseCache):

    async def disconnect(self):
        pass

    async def async_set_cache_pipeline(self, cache_list, **kwargs):
        tasks = []
        for val in cache_list:
            tasks.append(self.async_set_cache(val[0], val[1], **kwargs))
        await asyncio.gather(*tasks)
@@ -423,7 +423,7 @@ class SlackAlerting(CustomBatchLogger):
        latency_cache_keys = [(key, 0) for key in latency_keys]
        failed_request_cache_keys = [(key, 0) for key in failed_request_keys]
        combined_metrics_cache_keys = latency_cache_keys + failed_request_cache_keys
        await self.internal_usage_cache.async_batch_set_cache(
        await self.internal_usage_cache.async_set_cache_pipeline(
            cache_list=combined_metrics_cache_keys
        )
@@ -3,8 +3,9 @@
import copy
import os
import traceback
import types
from collections.abc import MutableMapping, MutableSequence, MutableSet
from typing import TYPE_CHECKING, Any, Dict, Optional
from typing import TYPE_CHECKING, Any, Dict, Optional, cast

from packaging.version import Version
from pydantic import BaseModel

@@ -355,17 +356,28 @@ class LangFuseLogger:
            )
        )

    def _prepare_metadata(self, metadata) -> Any:
    def is_base_type(self, value: Any) -> bool:
        # Check if the value is of a base type
        base_types = (int, float, str, bool, list, dict, tuple)
        return isinstance(value, base_types)

    def _prepare_metadata(self, metadata: Optional[dict]) -> Any:
        try:
            return copy.deepcopy(metadata)  # Avoid modifying the original metadata
        except (TypeError, copy.Error) as e:
            verbose_logger.warning(f"Langfuse Layer Error - {e}")
            if metadata is None:
                return None

            # Filter out function types from the metadata
            sanitized_metadata = {k: v for k, v in metadata.items() if not callable(v)}

            return copy.deepcopy(sanitized_metadata)
        except Exception as e:
            verbose_logger.debug(f"Langfuse Layer Error - {e}, metadata: {metadata}")

        new_metadata: Dict[str, Any] = {}

        # if metadata is not a MutableMapping, return an empty dict since we can't call items() on it
        if not isinstance(metadata, MutableMapping):
            verbose_logger.warning(
            verbose_logger.debug(
                "Langfuse Layer Logging - metadata is not a MutableMapping, returning empty dict"
            )
            return new_metadata

@@ -373,25 +385,40 @@ class LangFuseLogger:
        for key, value in metadata.items():
            try:
                if isinstance(value, MutableMapping):
                    new_metadata[key] = self._prepare_metadata(value)
                elif isinstance(value, (MutableSequence, MutableSet)):
                    new_metadata[key] = type(value)(
                        *(
                            (
                                self._prepare_metadata(v)
                                if isinstance(v, MutableMapping)
                                else copy.deepcopy(v)
                            )
                            for v in value
                    new_metadata[key] = self._prepare_metadata(cast(dict, value))
                elif isinstance(value, MutableSequence):
                    # For lists or other mutable sequences
                    new_metadata[key] = list(
                        (
                            self._prepare_metadata(cast(dict, v))
                            if isinstance(v, MutableMapping)
                            else copy.deepcopy(v)
                        )
                        for v in value
                    )
                elif isinstance(value, MutableSet):
                    # For sets specifically, create a new set by passing an iterable
                    new_metadata[key] = set(
                        (
                            self._prepare_metadata(cast(dict, v))
                            if isinstance(v, MutableMapping)
                            else copy.deepcopy(v)
                        )
                        for v in value
                    )
                elif isinstance(value, BaseModel):
                    new_metadata[key] = value.model_dump()
                elif self.is_base_type(value):
                    new_metadata[key] = value
                else:
                    new_metadata[key] = copy.deepcopy(value)
                    verbose_logger.debug(
                        f"Langfuse Layer Error - Unsupported metadata type: {type(value)} for key: {key}"
                    )
                    continue

            except (TypeError, copy.Error):
                verbose_logger.warning(
                    f"Langfuse Layer Error - Couldn't copy metadata key: {key} - {traceback.format_exc()}"
                verbose_logger.debug(
                    f"Langfuse Layer Error - Couldn't copy metadata key: {key}, type of key: {type(key)}, type of value: {type(value)} - {traceback.format_exc()}"
                )

        return new_metadata
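The practical effect of the new metadata handling is that non-copyable values (locks, functions) are dropped while plain values survive the deepcopy. A standalone sketch of the same filtering idea, independent of the LangFuseLogger class (function and variable names here are illustrative):

import copy
import threading

def sanitize_metadata_sketch(metadata: dict) -> dict:
    # Keep only values that deepcopy cleanly; skip callables outright.
    sanitized = {}
    for key, value in metadata.items():
        if callable(value):
            continue
        try:
            sanitized[key] = copy.deepcopy(value)
        except (TypeError, copy.Error):
            continue
    return sanitized

print(sanitize_metadata_sketch({"run_id": "abc", "lock": threading.Lock(), "fn": print}))
# -> {'run_id': 'abc'}; the lock fails deepcopy and the function is filtered out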
@@ -2774,11 +2774,6 @@ def get_standard_logging_object_payload(
            metadata=metadata
        )

    if litellm.cache is not None:
        cache_key = litellm.cache.get_cache_key(**kwargs)
    else:
        cache_key = None

    saved_cache_cost: float = 0.0
    if cache_hit is True:

@@ -2820,7 +2815,7 @@ def get_standard_logging_object_payload(
        completionStartTime=completion_start_time_float,
        model=kwargs.get("model", "") or "",
        metadata=clean_metadata,
        cache_key=cache_key,
        cache_key=clean_hidden_params["cache_key"],
        response_cost=response_cost,
        total_tokens=usage.total_tokens,
        prompt_tokens=usage.prompt_tokens,
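The point of this refactor is "compute the cache key once at lookup time, then reuse it at logging time" rather than regenerating it. A minimal sketch of that pattern (the dictionaries and function name here are illustrative, not litellm's internal types):

def lookup_with_key(cache, kwargs: dict) -> tuple:
    # Compute the key once during the cache lookup...
    cache_key = cache.get_cache_key(**kwargs)
    result = cache.get_cache(cache_key=cache_key, **kwargs)
    # ...and hand the same key to logging via hidden params instead of regenerating it.
    hidden_params = {"cache_key": cache_key, "cache_hit": result is not None}
    return result, hidden_params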
@@ -1,12 +1,80 @@
model_list:
  - model_name: "*"
    litellm_params:
      model: "*"
      model: claude-3-5-sonnet-20240620
      api_key: os.environ/ANTHROPIC_API_KEY
  - model_name: claude-3-5-sonnet-aihubmix
    litellm_params:
      model: openai/claude-3-5-sonnet-20240620
      input_cost_per_token: 0.000003 # 3$/M
      output_cost_per_token: 0.000015 # 15$/M
      api_base: "https://exampleopenaiendpoint-production.up.railway.app"
      api_key: my-fake-key
  - model_name: fake-openai-endpoint-2
    litellm_params:
      model: openai/my-fake-model
      api_key: my-fake-key
      api_base: https://exampleopenaiendpoint-production.up.railway.app/
      stream_timeout: 0.001
      timeout: 1
      rpm: 1
  - model_name: fake-openai-endpoint
    litellm_params:
      model: openai/my-fake-model
      api_key: my-fake-key
      api_base: https://exampleopenaiendpoint-production.up.railway.app/
  ## bedrock chat completions
  - model_name: "*anthropic.claude*"
    litellm_params:
      model: bedrock/*anthropic.claude*
      aws_access_key_id: os.environ/BEDROCK_AWS_ACCESS_KEY_ID
      aws_secret_access_key: os.environ/BEDROCK_AWS_SECRET_ACCESS_KEY
      aws_region_name: os.environ/AWS_REGION_NAME
      guardrailConfig:
        "guardrailIdentifier": "h4dsqwhp6j66"
        "guardrailVersion": "2"
        "trace": "enabled"

  ## bedrock embeddings
  - model_name: "*amazon.titan-embed-*"
    litellm_params:
      model: bedrock/amazon.titan-embed-*
      aws_access_key_id: os.environ/BEDROCK_AWS_ACCESS_KEY_ID
      aws_secret_access_key: os.environ/BEDROCK_AWS_SECRET_ACCESS_KEY
      aws_region_name: os.environ/AWS_REGION_NAME
  - model_name: "*cohere.embed-*"
    litellm_params:
      model: bedrock/cohere.embed-*
      aws_access_key_id: os.environ/BEDROCK_AWS_ACCESS_KEY_ID
      aws_secret_access_key: os.environ/BEDROCK_AWS_SECRET_ACCESS_KEY
      aws_region_name: os.environ/AWS_REGION_NAME

  - model_name: "bedrock/*"
    litellm_params:
      model: bedrock/*
      aws_access_key_id: os.environ/BEDROCK_AWS_ACCESS_KEY_ID
      aws_secret_access_key: os.environ/BEDROCK_AWS_SECRET_ACCESS_KEY
      aws_region_name: os.environ/AWS_REGION_NAME

  - model_name: gpt-4
    litellm_params:
      model: azure/chatgpt-v-2
      api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
      api_version: "2023-05-15"
      api_key: os.environ/AZURE_API_KEY # The `os.environ/` prefix tells litellm to read this from the env. See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault
      rpm: 480
      timeout: 300
      stream_timeout: 60

litellm_settings:
  fallbacks: [{ "claude-3-5-sonnet-20240620": ["claude-3-5-sonnet-aihubmix"] }]
  callbacks: ["otel", "prometheus"]
  default_redis_batch_cache_expiry: 10
  # default_team_settings:
  #   - team_id: "dbe2f686-a686-4896-864a-4c3924458709"
  #     success_callback: ["langfuse"]
  #     langfuse_public_key: os.environ/LANGFUSE_PUB_KEY_1 # Project 1
  #     langfuse_secret: os.environ/LANGFUSE_PRIVATE_KEY_1 # Project 1

# litellm_settings:
#   cache: True
@@ -1308,7 +1308,7 @@ async def update_cache(  # noqa: PLR0915
        await _update_team_cache()

    asyncio.create_task(
        user_api_key_cache.async_batch_set_cache(
        user_api_key_cache.async_set_cache_pipeline(
            cache_list=values_to_update_in_cache,
            ttl=60,
            litellm_parent_otel_span=parent_otel_span,

@@ -2978,7 +2978,7 @@ class ProxyStartupEvent:

        if (
            proxy_logging_obj is not None
            and proxy_logging_obj.slack_alerting_instance is not None
            and proxy_logging_obj.slack_alerting_instance.alerting is not None
            and prisma_client is not None
        ):
            print("Alerting: Initializing Weekly/Monthly Spend Reports")  # noqa
@@ -175,7 +175,7 @@ class InternalUsageCache:
        local_only: bool = False,
        **kwargs,
    ) -> None:
        return await self.dual_cache.async_batch_set_cache(
        return await self.dual_cache.async_set_cache_pipeline(
            cache_list=cache_list,
            local_only=local_only,
            litellm_parent_otel_span=litellm_parent_otel_span,
@@ -339,11 +339,7 @@ class Router:
        cache_config: Dict[str, Any] = {}

        self.client_ttl = client_ttl
        if redis_url is not None or (
            redis_host is not None
            and redis_port is not None
            and redis_password is not None
        ):
        if redis_url is not None or (redis_host is not None and redis_port is not None):
            cache_type = "redis"

            if redis_url is not None:
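After dropping the password requirement, a Router pointed at an unauthenticated Redis (for example a local dev instance) can enable the shared cache with just host and port. A usage sketch (the model_list contents are illustrative):

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {"model": "gpt-3.5-turbo"},
        }
    ],
    redis_host="localhost",
    redis_port=6379,  # no redis_password needed any more for the redis cache path
)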
@@ -796,7 +796,7 @@ def client(original_function):  # noqa: PLR0915
                and kwargs.get("_arealtime", False) is not True
            ):  # allow users to control returning cached responses from the completion function
                # checking cache
                print_verbose("INSIDE CHECKING CACHE")
                verbose_logger.debug("INSIDE CHECKING SYNC CACHE")
                caching_handler_response: CachingHandlerResponse = (
                    _llm_caching_handler._sync_get_cache(
                        model=model or "",

@@ -808,6 +808,7 @@ def client(original_function):  # noqa: PLR0915
                        args=args,
                    )
                )

                if caching_handler_response.cached_result is not None:
                    return caching_handler_response.cached_result
tests/local_testing/cache_unit_tests.py (new file, 223 lines)
@@ -0,0 +1,223 @@
from abc import ABC, abstractmethod
from litellm.caching import LiteLLMCacheType
import os
import sys
import time
import traceback
import uuid

from dotenv import load_dotenv
from test_rerank import assert_response_shape

load_dotenv()
import os

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import asyncio
import hashlib
import random

import pytest

import litellm
from litellm.caching import Cache
from litellm import completion, embedding


class LLMCachingUnitTests(ABC):

    @abstractmethod
    def get_cache_type(self) -> LiteLLMCacheType:
        pass

    @pytest.mark.parametrize("sync_mode", [True, False])
    @pytest.mark.asyncio
    async def test_cache_completion(self, sync_mode):
        litellm._turn_on_debug()

        random_number = random.randint(
            1, 100000
        )  # add a random number to ensure it's always adding / reading from cache
        messages = [
            {
                "role": "user",
                "content": f"write a one sentence poem about: {random_number}",
            }
        ]

        cache_type = self.get_cache_type()
        litellm.cache = Cache(
            type=cache_type,
        )

        if sync_mode:
            response1 = completion(
                "gpt-3.5-turbo",
                messages=messages,
                caching=True,
                max_tokens=20,
                mock_response="This number is so great!",
            )
        else:
            response1 = await litellm.acompletion(
                "gpt-3.5-turbo",
                messages=messages,
                caching=True,
                max_tokens=20,
                mock_response="This number is so great!",
            )
        # response2 is mocked to a different response from response1,
        # but the completion from the cache should be used instead of the mock
        # response since the input is the same as response1
        await asyncio.sleep(0.5)
        if sync_mode:
            response2 = completion(
                "gpt-3.5-turbo",
                messages=messages,
                caching=True,
                max_tokens=20,
                mock_response="This number is great!",
            )
        else:
            response2 = await litellm.acompletion(
                "gpt-3.5-turbo",
                messages=messages,
                caching=True,
                max_tokens=20,
                mock_response="This number is great!",
            )
        if (
            response1["choices"][0]["message"]["content"]
            != response2["choices"][0]["message"]["content"]
        ):  # 1 and 2 should be the same
            # 1&2 have the exact same input params. This MUST Be a CACHE HIT
            print(f"response1: {response1}")
            print(f"response2: {response2}")
            pytest.fail(
                f"Error occurred: response1 - {response1['choices'][0]['message']['content']} != response2 - {response2['choices'][0]['message']['content']}"
            )
        # Since the parameters are not the same as response1, response3 should actually
        # be the mock response
        if sync_mode:
            response3 = completion(
                "gpt-3.5-turbo",
                messages=messages,
                caching=True,
                temperature=0.5,
                mock_response="This number is awful!",
            )
        else:
            response3 = await litellm.acompletion(
                "gpt-3.5-turbo",
                messages=messages,
                caching=True,
                temperature=0.5,
                mock_response="This number is awful!",
            )

        print("\nresponse 1", response1)
        print("\nresponse 2", response2)
        print("\nresponse 3", response3)
        # print("\nresponse 4", response4)
        litellm.cache = None
        litellm.success_callback = []
        litellm._async_success_callback = []

        # 1 & 2 should be exactly the same
        # 1 & 3 should be different, since input params are diff

        if (
            response1["choices"][0]["message"]["content"]
            == response3["choices"][0]["message"]["content"]
        ):
            # if input params like max_tokens, temperature are diff it should NOT be a cache hit
            print(f"response1: {response1}")
            print(f"response3: {response3}")
            pytest.fail(
                f"Response 1 == response 3. Same model, diff params shoudl not cache Error"
                f" occurred:"
            )

        assert response1.id == response2.id
        assert response1.created == response2.created
        assert (
            response1.choices[0].message.content == response2.choices[0].message.content
        )

    @pytest.mark.parametrize("sync_mode", [True, False])
    @pytest.mark.asyncio
    async def test_disk_cache_embedding(self, sync_mode):
        litellm._turn_on_debug()

        random_number = random.randint(
            1, 100000
        )  # add a random number to ensure it's always adding / reading from cache
        input = [f"hello {random_number}"]
        litellm.cache = Cache(
            type="disk",
        )

        if sync_mode:
            response1 = embedding(
                "openai/text-embedding-ada-002",
                input=input,
                caching=True,
            )
        else:
            response1 = await litellm.aembedding(
                "openai/text-embedding-ada-002",
                input=input,
                caching=True,
            )
        # response2 is mocked to a different response from response1,
        # but the completion from the cache should be used instead of the mock
        # response since the input is the same as response1
        await asyncio.sleep(0.5)
        if sync_mode:
            response2 = embedding(
                "openai/text-embedding-ada-002",
                input=input,
                caching=True,
            )
        else:
            response2 = await litellm.aembedding(
                "openai/text-embedding-ada-002",
                input=input,
                caching=True,
            )

        if response2._hidden_params["cache_hit"] is not True:
            pytest.fail("Cache hit should be True")

        # Since the parameters are not the same as response1, response3 should actually
        # be the mock response
        if sync_mode:
            response3 = embedding(
                "openai/text-embedding-ada-002",
                input=input,
                user="charlie",
                caching=True,
            )
        else:
            response3 = await litellm.aembedding(
                "openai/text-embedding-ada-002",
                input=input,
                caching=True,
                user="charlie",
            )

        print("\nresponse 1", response1)
        print("\nresponse 2", response2)
        print("\nresponse 3", response3)
        # print("\nresponse 4", response4)
        litellm.cache = None
        litellm.success_callback = []
        litellm._async_success_callback = []

        # 1 & 2 should be exactly the same
        # 1 & 3 should be different, since input params are diff

        if response3._hidden_params.get("cache_hit") is True:
            pytest.fail("Cache hit should not be True")
@@ -438,7 +438,7 @@ async def test_send_daily_reports_ignores_zero_values():
    slack_alerting.internal_usage_cache.async_batch_get_cache = AsyncMock(
        return_value=[None, 0, 10, 0, 0, None]
    )
    slack_alerting.internal_usage_cache.async_batch_set_cache = AsyncMock()
    slack_alerting.internal_usage_cache.async_set_cache_pipeline = AsyncMock()

    router.get_model_info.side_effect = lambda x: {"litellm_params": {"model": x}}
@@ -1103,81 +1103,6 @@ async def test_redis_cache_acompletion_stream_bedrock():
        raise e


def test_disk_cache_completion():
    litellm.set_verbose = False

    random_number = random.randint(
        1, 100000
    )  # add a random number to ensure it's always adding / reading from cache
    messages = [
        {"role": "user", "content": f"write a one sentence poem about: {random_number}"}
    ]
    litellm.cache = Cache(
        type="disk",
    )

    response1 = completion(
        model="gpt-3.5-turbo",
        messages=messages,
        caching=True,
        max_tokens=20,
        mock_response="This number is so great!",
    )
    # response2 is mocked to a different response from response1,
    # but the completion from the cache should be used instead of the mock
    # response since the input is the same as response1
    response2 = completion(
        model="gpt-3.5-turbo",
        messages=messages,
        caching=True,
        max_tokens=20,
        mock_response="This number is awful!",
    )
    # Since the parameters are not the same as response1, response3 should actually
    # be the mock response
    response3 = completion(
        model="gpt-3.5-turbo",
        messages=messages,
        caching=True,
        temperature=0.5,
        mock_response="This number is awful!",
    )

    print("\nresponse 1", response1)
    print("\nresponse 2", response2)
    print("\nresponse 3", response3)
    # print("\nresponse 4", response4)
    litellm.cache = None
    litellm.success_callback = []
    litellm._async_success_callback = []

    # 1 & 2 should be exactly the same
    # 1 & 3 should be different, since input params are diff
    if (
        response1["choices"][0]["message"]["content"]
        != response2["choices"][0]["message"]["content"]
    ):  # 1 and 2 should be the same
        # 1&2 have the exact same input params. This MUST Be a CACHE HIT
        print(f"response1: {response1}")
        print(f"response2: {response2}")
        pytest.fail(f"Error occurred:")
    if (
        response1["choices"][0]["message"]["content"]
        == response3["choices"][0]["message"]["content"]
    ):
        # if input params like max_tokens, temperature are diff it should NOT be a cache hit
        print(f"response1: {response1}")
        print(f"response3: {response3}")
        pytest.fail(
            f"Response 1 == response 3. Same model, diff params shoudl not cache Error"
            f" occurred:"
        )

    assert response1.id == response2.id
    assert response1.created == response2.created
    assert response1.choices[0].message.content == response2.choices[0].message.content


# @pytest.mark.skip(reason="AWS Suspended Account")
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
tests/local_testing/test_disk_cache_unit_tests.py (new file, 11 lines)
@@ -0,0 +1,11 @@
from cache_unit_tests import LLMCachingUnitTests
from litellm.caching import LiteLLMCacheType


class TestDiskCacheUnitTests(LLMCachingUnitTests):
    def get_cache_type(self) -> LiteLLMCacheType:
        return LiteLLMCacheType.DISK


# if __name__ == "__main__":
#     pytest.main([__file__, "-v", "-s"])
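The new ABC makes adding coverage for another backend a two-line subclass. A hypothetical Redis variant (TestRedisCacheUnitTests is not part of this PR) would look like:

from cache_unit_tests import LLMCachingUnitTests
from litellm.caching import LiteLLMCacheType


class TestRedisCacheUnitTests(LLMCachingUnitTests):
    def get_cache_type(self) -> LiteLLMCacheType:
        return LiteLLMCacheType.REDIS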
@@ -146,7 +146,7 @@ async def test_dual_cache_batch_operations(is_async):

    # Set values
    if is_async:
        await dual_cache.async_batch_set_cache(cache_list)
        await dual_cache.async_set_cache_pipeline(cache_list)
    else:
        for key, value in cache_list:
            dual_cache.set_cache(key, value)
@@ -212,26 +212,48 @@ def test_get_langfuse_logger_for_request_with_cached_logger():
    assert result == cached_logger
    mock_cache.get_cache.assert_called_once()

@pytest.mark.parametrize("metadata", [
    {'a': 1, 'b': 2, 'c': 3},
    {'a': {'nested_a': 1}, 'b': {'nested_b': 2}},
    {'a': [1, 2, 3], 'b': {4, 5, 6}},
    {'a': (1, 2), 'b': frozenset([3, 4]), 'c': {'d': [5, 6]}},
    {'lock': threading.Lock()},
    {'func': lambda x: x + 1},
    {
        'int': 42,
        'str': 'hello',
        'list': [1, 2, 3],
        'set': {4, 5},
        'dict': {'nested': 'value'},
        'non_copyable': threading.Lock(),
        'function': print
    },
    ['list', 'not', 'a', 'dict'],
    {'timestamp': datetime.now()},
    {},
    None,
])
def test_langfuse_logger_prepare_metadata(metadata):
    global_langfuse_logger._prepare_metadata(metadata)

@pytest.mark.parametrize(
    "metadata, expected_metadata",
    [
        ({"a": 1, "b": 2, "c": 3}, {"a": 1, "b": 2, "c": 3}),
        (
            {"a": {"nested_a": 1}, "b": {"nested_b": 2}},
            {"a": {"nested_a": 1}, "b": {"nested_b": 2}},
        ),
        ({"a": [1, 2, 3], "b": {4, 5, 6}}, {"a": [1, 2, 3], "b": {4, 5, 6}}),
        (
            {"a": (1, 2), "b": frozenset([3, 4]), "c": {"d": [5, 6]}},
            {"a": (1, 2), "b": frozenset([3, 4]), "c": {"d": [5, 6]}},
        ),
        ({"lock": threading.Lock()}, {}),
        ({"func": lambda x: x + 1}, {}),
        (
            {
                "int": 42,
                "str": "hello",
                "list": [1, 2, 3],
                "set": {4, 5},
                "dict": {"nested": "value"},
                "non_copyable": threading.Lock(),
                "function": print,
            },
            {
                "int": 42,
                "str": "hello",
                "list": [1, 2, 3],
                "set": {4, 5},
                "dict": {"nested": "value"},
            },
        ),
        (
            {"list": ["list", "not", "a", "dict"]},
            {"list": ["list", "not", "a", "dict"]},
        ),
        ({}, {}),
        (None, None),
    ],
)
def test_langfuse_logger_prepare_metadata(metadata, expected_metadata):
    result = global_langfuse_logger._prepare_metadata(metadata)
    assert result == expected_metadata