LiteLLM Minor Fixes & Improvements (11/12/2024) (#6705)

* fix(caching): convert positional args to their equivalent kwargs in the llm caching handler

prevents unexpected errors when callers pass positional arguments

* fix(caching_handler.py): don't pass args to caching

* fix(caching): remove all *args from caching.py (a short usage sketch follows this change list)

* fix(caching): consistent function signatures + abstract method on BaseCache

* test(caching_unit_tests.py): add unit tests for llm caching

ensures coverage for common caching scenarios across different implementations

* refactor(litellm_logging.py): move to using cache key from hidden params instead of regenerating one

* fix(router.py): drop redis password requirement

* fix(proxy_server.py): fix faulty slack alerting check

* fix(langfuse.py): avoid copying functions/thread-lock objects in metadata

fixes the metadata copy error raised when a parent OTEL span is present in the metadata

* test: update test
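
For reference, after the *args removal the public Cache helpers in caching.py are keyword-only. A minimal usage sketch of the expected call shape (the model and message values are illustrative):

import litellm
from litellm.caching import Cache

litellm.cache = Cache(type="local")  # in-memory cache, for illustration only

call_kwargs = {
    "model": "gpt-3.5-turbo",
    "messages": [{"role": "user", "content": "hi"}],
}

# positional *args are no longer accepted by these helpers
cache_key = litellm.cache.get_cache_key(**call_kwargs)
cached = litellm.cache.get_cache(**call_kwargs)
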
Krish Dholakia 2024-11-12 22:50:51 +05:30 committed by GitHub
parent d39fd60801
commit 9160d80fa5
23 changed files with 525 additions and 204 deletions

View file

@ -8,6 +8,7 @@ Has 4 methods:
- async_get_cache
"""
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Optional
if TYPE_CHECKING:
@ -18,7 +19,7 @@ else:
Span = Any
class BaseCache:
class BaseCache(ABC):
def __init__(self, default_ttl: int = 60):
self.default_ttl = default_ttl
@ -37,6 +38,10 @@ class BaseCache:
async def async_set_cache(self, key, value, **kwargs):
raise NotImplementedError
@abstractmethod
async def async_set_cache_pipeline(self, cache_list, **kwargs):
pass
def get_cache(self, key, **kwargs):
raise NotImplementedError
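
Since BaseCache is now an ABC and async_set_cache_pipeline is abstract, every backend must implement it. A minimal sketch of a hypothetical custom backend under the new contract (the class name and in-memory store are illustrative, not part of this commit):

class InMemoryListCache(BaseCache):
    def __init__(self, default_ttl: int = 60):
        super().__init__(default_ttl=default_ttl)
        self.store = {}

    def set_cache(self, key, value, **kwargs):
        self.store[key] = value

    async def async_set_cache(self, key, value, **kwargs):
        self.set_cache(key, value, **kwargs)

    async def async_set_cache_pipeline(self, cache_list, **kwargs):
        # required by the new abstract method: bulk-write (key, value) pairs
        for key, value in cache_list:
            await self.async_set_cache(key, value, **kwargs)

    def get_cache(self, key, **kwargs):
        return self.store.get(key)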

View file

@ -233,19 +233,18 @@ class Cache:
if self.namespace is not None and isinstance(self.cache, RedisCache):
self.cache.namespace = self.namespace
def get_cache_key(self, *args, **kwargs) -> str:
def get_cache_key(self, **kwargs) -> str:
"""
Get the cache key for the given arguments.
Args:
*args: args to litellm.completion() or embedding()
**kwargs: kwargs to litellm.completion() or embedding()
Returns:
str: The cache key generated from the arguments, or None if no cache key could be generated.
"""
cache_key = ""
verbose_logger.debug("\nGetting Cache key. Kwargs: %s", kwargs)
# verbose_logger.debug("\nGetting Cache key. Kwargs: %s", kwargs)
preset_cache_key = self._get_preset_cache_key_from_kwargs(**kwargs)
if preset_cache_key is not None:
@ -521,7 +520,7 @@ class Cache:
return cached_response
return cached_result
def get_cache(self, *args, **kwargs):
def get_cache(self, **kwargs):
"""
Retrieves the cached result for the given arguments.
@ -533,13 +532,13 @@ class Cache:
The cached result if it exists, otherwise None.
"""
try: # never block execution
if self.should_use_cache(*args, **kwargs) is not True:
if self.should_use_cache(**kwargs) is not True:
return
messages = kwargs.get("messages", [])
if "cache_key" in kwargs:
cache_key = kwargs["cache_key"]
else:
cache_key = self.get_cache_key(*args, **kwargs)
cache_key = self.get_cache_key(**kwargs)
if cache_key is not None:
cache_control_args = kwargs.get("cache", {})
max_age = cache_control_args.get(
@ -553,29 +552,28 @@ class Cache:
print_verbose(f"An exception occurred: {traceback.format_exc()}")
return None
async def async_get_cache(self, *args, **kwargs):
async def async_get_cache(self, **kwargs):
"""
Async get cache implementation.
Used for embedding calls in async wrapper
"""
try: # never block execution
if self.should_use_cache(*args, **kwargs) is not True:
if self.should_use_cache(**kwargs) is not True:
return
kwargs.get("messages", [])
if "cache_key" in kwargs:
cache_key = kwargs["cache_key"]
else:
cache_key = self.get_cache_key(*args, **kwargs)
cache_key = self.get_cache_key(**kwargs)
if cache_key is not None:
cache_control_args = kwargs.get("cache", {})
max_age = cache_control_args.get(
"s-max-age", cache_control_args.get("s-maxage", float("inf"))
)
cached_result = await self.cache.async_get_cache(
cache_key, *args, **kwargs
)
cached_result = await self.cache.async_get_cache(cache_key, **kwargs)
return self._get_cache_logic(
cached_result=cached_result, max_age=max_age
)
@ -583,7 +581,7 @@ class Cache:
print_verbose(f"An exception occurred: {traceback.format_exc()}")
return None
def _add_cache_logic(self, result, *args, **kwargs):
def _add_cache_logic(self, result, **kwargs):
"""
Common implementation across sync + async add_cache functions
"""
@ -591,7 +589,7 @@ class Cache:
if "cache_key" in kwargs:
cache_key = kwargs["cache_key"]
else:
cache_key = self.get_cache_key(*args, **kwargs)
cache_key = self.get_cache_key(**kwargs)
if cache_key is not None:
if isinstance(result, BaseModel):
result = result.model_dump_json()
@ -613,7 +611,7 @@ class Cache:
except Exception as e:
raise e
def add_cache(self, result, *args, **kwargs):
def add_cache(self, result, **kwargs):
"""
Adds a result to the cache.
@ -625,41 +623,42 @@ class Cache:
None
"""
try:
if self.should_use_cache(*args, **kwargs) is not True:
if self.should_use_cache(**kwargs) is not True:
return
cache_key, cached_data, kwargs = self._add_cache_logic(
result=result, *args, **kwargs
result=result, **kwargs
)
self.cache.set_cache(cache_key, cached_data, **kwargs)
except Exception as e:
verbose_logger.exception(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
async def async_add_cache(self, result, *args, **kwargs):
async def async_add_cache(self, result, **kwargs):
"""
Async implementation of add_cache
"""
try:
if self.should_use_cache(*args, **kwargs) is not True:
if self.should_use_cache(**kwargs) is not True:
return
if self.type == "redis" and self.redis_flush_size is not None:
# high traffic - fill in results in memory and then flush
await self.batch_cache_write(result, *args, **kwargs)
await self.batch_cache_write(result, **kwargs)
else:
cache_key, cached_data, kwargs = self._add_cache_logic(
result=result, *args, **kwargs
result=result, **kwargs
)
await self.cache.async_set_cache(cache_key, cached_data, **kwargs)
except Exception as e:
verbose_logger.exception(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
async def async_add_cache_pipeline(self, result, *args, **kwargs):
async def async_add_cache_pipeline(self, result, **kwargs):
"""
Async implementation of add_cache for Embedding calls
Does a bulk write, to prevent using too many clients
"""
try:
if self.should_use_cache(*args, **kwargs) is not True:
if self.should_use_cache(**kwargs) is not True:
return
# set default ttl if not set
@ -668,29 +667,27 @@ class Cache:
cache_list = []
for idx, i in enumerate(kwargs["input"]):
preset_cache_key = self.get_cache_key(*args, **{**kwargs, "input": i})
preset_cache_key = self.get_cache_key(**{**kwargs, "input": i})
kwargs["cache_key"] = preset_cache_key
embedding_response = result.data[idx]
cache_key, cached_data, kwargs = self._add_cache_logic(
result=embedding_response,
*args,
**kwargs,
)
cache_list.append((cache_key, cached_data))
async_set_cache_pipeline = getattr(
self.cache, "async_set_cache_pipeline", None
)
if async_set_cache_pipeline:
await async_set_cache_pipeline(cache_list=cache_list, **kwargs)
else:
tasks = []
for val in cache_list:
tasks.append(self.cache.async_set_cache(val[0], val[1], **kwargs))
await asyncio.gather(*tasks)
await self.cache.async_set_cache_pipeline(cache_list=cache_list, **kwargs)
# if async_set_cache_pipeline:
# await async_set_cache_pipeline(cache_list=cache_list, **kwargs)
# else:
# tasks = []
# for val in cache_list:
# tasks.append(self.cache.async_set_cache(val[0], val[1], **kwargs))
# await asyncio.gather(*tasks)
except Exception as e:
verbose_logger.exception(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
def should_use_cache(self, *args, **kwargs):
def should_use_cache(self, **kwargs):
"""
Returns true if we should use the cache for LLM API calls
@ -708,10 +705,8 @@ class Cache:
return True
return False
async def batch_cache_write(self, result, *args, **kwargs):
cache_key, cached_data, kwargs = self._add_cache_logic(
result=result, *args, **kwargs
)
async def batch_cache_write(self, result, **kwargs):
cache_key, cached_data, kwargs = self._add_cache_logic(result=result, **kwargs)
await self.cache.batch_cache_write(cache_key, cached_data, **kwargs)
async def ping(self):

View file

@ -137,7 +137,7 @@ class LLMCachingHandler:
if litellm.cache is not None and self._is_call_type_supported_by_cache(
original_function=original_function
):
print_verbose("Checking Cache")
verbose_logger.debug("Checking Cache")
cached_result = await self._retrieve_from_cache(
call_type=call_type,
kwargs=kwargs,
@ -145,7 +145,7 @@ class LLMCachingHandler:
)
if cached_result is not None and not isinstance(cached_result, list):
print_verbose("Cache Hit!")
verbose_logger.debug("Cache Hit!")
cache_hit = True
end_time = datetime.datetime.now()
model, _, _, _ = litellm.get_llm_provider(
@ -215,6 +215,7 @@ class LLMCachingHandler:
final_embedding_cached_response=final_embedding_cached_response,
embedding_all_elements_cache_hit=embedding_all_elements_cache_hit,
)
verbose_logger.debug(f"CACHE RESULT: {cached_result}")
return CachingHandlerResponse(
cached_result=cached_result,
final_embedding_cached_response=final_embedding_cached_response,
@ -233,12 +234,19 @@ class LLMCachingHandler:
from litellm.utils import CustomStreamWrapper
args = args or ()
new_kwargs = kwargs.copy()
new_kwargs.update(
convert_args_to_kwargs(
self.original_function,
args,
)
)
cached_result: Optional[Any] = None
if litellm.cache is not None and self._is_call_type_supported_by_cache(
original_function=original_function
):
print_verbose("Checking Cache")
cached_result = litellm.cache.get_cache(*args, **kwargs)
cached_result = litellm.cache.get_cache(**new_kwargs)
if cached_result is not None:
if "detail" in cached_result:
# implies an error occurred
@ -475,14 +483,21 @@ class LLMCachingHandler:
if litellm.cache is None:
return None
new_kwargs = kwargs.copy()
new_kwargs.update(
convert_args_to_kwargs(
self.original_function,
args,
)
)
cached_result: Optional[Any] = None
if call_type == CallTypes.aembedding.value and isinstance(
kwargs["input"], list
new_kwargs["input"], list
):
tasks = []
for idx, i in enumerate(kwargs["input"]):
for idx, i in enumerate(new_kwargs["input"]):
preset_cache_key = litellm.cache.get_cache_key(
*args, **{**kwargs, "input": i}
**{**new_kwargs, "input": i}
)
tasks.append(litellm.cache.async_get_cache(cache_key=preset_cache_key))
cached_result = await asyncio.gather(*tasks)
@ -493,9 +508,9 @@ class LLMCachingHandler:
cached_result = None
else:
if litellm.cache._supports_async() is True:
cached_result = await litellm.cache.async_get_cache(*args, **kwargs)
cached_result = await litellm.cache.async_get_cache(**new_kwargs)
else: # for s3 caching. [NOT RECOMMENDED IN PROD - this will slow down responses since boto3 is sync]
cached_result = litellm.cache.get_cache(*args, **kwargs)
cached_result = litellm.cache.get_cache(**new_kwargs)
return cached_result
def _convert_cached_result_to_model_response(
@ -580,6 +595,7 @@ class LLMCachingHandler:
model_response_object=EmbeddingResponse(),
response_type="embedding",
)
elif (
call_type == CallTypes.arerank.value or call_type == CallTypes.rerank.value
) and isinstance(cached_result, dict):
@ -603,6 +619,13 @@ class LLMCachingHandler:
response_type="audio_transcription",
hidden_params=hidden_params,
)
if (
hasattr(cached_result, "_hidden_params")
and cached_result._hidden_params is not None
and isinstance(cached_result._hidden_params, dict)
):
cached_result._hidden_params["cache_hit"] = True
return cached_result
def _convert_cached_stream_response(
@ -658,12 +681,19 @@ class LLMCachingHandler:
Raises:
None
"""
kwargs.update(convert_args_to_kwargs(result, original_function, kwargs, args))
new_kwargs = kwargs.copy()
new_kwargs.update(
convert_args_to_kwargs(
original_function,
args,
)
)
if litellm.cache is None:
return
# [OPTIONAL] ADD TO CACHE
if self._should_store_result_in_cache(
original_function=original_function, kwargs=kwargs
original_function=original_function, kwargs=new_kwargs
):
if (
isinstance(result, litellm.ModelResponse)
@ -673,29 +703,29 @@ class LLMCachingHandler:
):
if (
isinstance(result, EmbeddingResponse)
and isinstance(kwargs["input"], list)
and isinstance(new_kwargs["input"], list)
and litellm.cache is not None
and not isinstance(
litellm.cache.cache, S3Cache
) # s3 doesn't support bulk writing. Exclude.
):
asyncio.create_task(
litellm.cache.async_add_cache_pipeline(result, **kwargs)
litellm.cache.async_add_cache_pipeline(result, **new_kwargs)
)
elif isinstance(litellm.cache.cache, S3Cache):
threading.Thread(
target=litellm.cache.add_cache,
args=(result,),
kwargs=kwargs,
kwargs=new_kwargs,
).start()
else:
asyncio.create_task(
litellm.cache.async_add_cache(
result.model_dump_json(), **kwargs
result.model_dump_json(), **new_kwargs
)
)
else:
asyncio.create_task(litellm.cache.async_add_cache(result, **kwargs))
asyncio.create_task(litellm.cache.async_add_cache(result, **new_kwargs))
def sync_set_cache(
self,
@ -706,16 +736,20 @@ class LLMCachingHandler:
"""
Sync internal method to add the result to the cache
"""
kwargs.update(
convert_args_to_kwargs(result, self.original_function, kwargs, args)
new_kwargs = kwargs.copy()
new_kwargs.update(
convert_args_to_kwargs(
self.original_function,
args,
)
)
if litellm.cache is None:
return
if self._should_store_result_in_cache(
original_function=self.original_function, kwargs=kwargs
original_function=self.original_function, kwargs=new_kwargs
):
litellm.cache.add_cache(result, **kwargs)
litellm.cache.add_cache(result, **new_kwargs)
return
@ -865,9 +899,7 @@ class LLMCachingHandler:
def convert_args_to_kwargs(
result: Any,
original_function: Callable,
kwargs: Dict[str, Any],
args: Optional[Tuple[Any, ...]] = None,
) -> Dict[str, Any]:
# Get the signature of the original function
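
The hunk above is truncated at the signature; a hedged sketch of how positional args can be mapped onto parameter names via inspect.signature (assumed behavior, not the verbatim implementation):

import inspect
from typing import Any, Callable, Dict, Optional, Tuple

def convert_args_to_kwargs(
    original_function: Callable,
    args: Optional[Tuple[Any, ...]] = None,
) -> Dict[str, Any]:
    # Get the signature of the original function
    signature = inspect.signature(original_function)
    param_names = list(signature.parameters.keys())

    # Map each positional arg onto its corresponding parameter name
    args_to_kwargs: Dict[str, Any] = {}
    if args:
        for index, value in enumerate(args):
            if index < len(param_names):
                args_to_kwargs[param_names[index]] = value
    return args_to_kwargs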

View file

@ -24,7 +24,6 @@ class DiskCache(BaseCache):
self.disk_cache = dc.Cache(disk_cache_dir)
def set_cache(self, key, value, **kwargs):
print_verbose("DiskCache: set_cache")
if "ttl" in kwargs:
self.disk_cache.set(key, value, expire=kwargs["ttl"])
else:
@ -33,10 +32,10 @@ class DiskCache(BaseCache):
async def async_set_cache(self, key, value, **kwargs):
self.set_cache(key=key, value=value, **kwargs)
async def async_set_cache_pipeline(self, cache_list, ttl=None):
async def async_set_cache_pipeline(self, cache_list, **kwargs):
for cache_key, cache_value in cache_list:
if ttl is not None:
self.set_cache(key=cache_key, value=cache_value, ttl=ttl)
if "ttl" in kwargs:
self.set_cache(key=cache_key, value=cache_value, ttl=kwargs["ttl"])
else:
self.set_cache(key=cache_key, value=cache_value)

View file

@ -314,7 +314,8 @@ class DualCache(BaseCache):
f"LiteLLM Cache: Excepton async add_cache: {str(e)}"
)
async def async_batch_set_cache(
# async_batch_set_cache
async def async_set_cache_pipeline(
self, cache_list: list, local_only: bool = False, **kwargs
):
"""

View file

@ -9,6 +9,7 @@ Has 4 methods:
"""
import ast
import asyncio
import json
from typing import Any
@ -422,3 +423,9 @@ class QdrantSemanticCache(BaseCache):
async def _collection_info(self):
return self.collection_info
async def async_set_cache_pipeline(self, cache_list, **kwargs):
tasks = []
for val in cache_list:
tasks.append(self.async_set_cache(val[0], val[1], **kwargs))
await asyncio.gather(*tasks)

View file

@ -404,7 +404,7 @@ class RedisCache(BaseCache):
parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
)
)
return results
return None
except Exception as e:
## LOGGING ##
end_time = time.time()

View file

@ -9,6 +9,7 @@ Has 4 methods:
"""
import ast
import asyncio
import json
from typing import Any
@ -331,3 +332,9 @@ class RedisSemanticCache(BaseCache):
async def _index_info(self):
return await self.index.ainfo()
async def async_set_cache_pipeline(self, cache_list, **kwargs):
tasks = []
for val in cache_list:
tasks.append(self.async_set_cache(val[0], val[1], **kwargs))
await asyncio.gather(*tasks)

View file

@ -10,6 +10,7 @@ Has 4 methods:
"""
import ast
import asyncio
import json
from typing import Any, Optional
@ -153,3 +154,9 @@ class S3Cache(BaseCache):
async def disconnect(self):
pass
async def async_set_cache_pipeline(self, cache_list, **kwargs):
tasks = []
for val in cache_list:
tasks.append(self.async_set_cache(val[0], val[1], **kwargs))
await asyncio.gather(*tasks)

View file

@ -423,7 +423,7 @@ class SlackAlerting(CustomBatchLogger):
latency_cache_keys = [(key, 0) for key in latency_keys]
failed_request_cache_keys = [(key, 0) for key in failed_request_keys]
combined_metrics_cache_keys = latency_cache_keys + failed_request_cache_keys
await self.internal_usage_cache.async_batch_set_cache(
await self.internal_usage_cache.async_set_cache_pipeline(
cache_list=combined_metrics_cache_keys
)

View file

@ -3,8 +3,9 @@
import copy
import os
import traceback
import types
from collections.abc import MutableMapping, MutableSequence, MutableSet
from typing import TYPE_CHECKING, Any, Dict, Optional
from typing import TYPE_CHECKING, Any, Dict, Optional, cast
from packaging.version import Version
from pydantic import BaseModel
@ -355,17 +356,28 @@ class LangFuseLogger:
)
)
def _prepare_metadata(self, metadata) -> Any:
def is_base_type(self, value: Any) -> bool:
# Check if the value is of a base type
base_types = (int, float, str, bool, list, dict, tuple)
return isinstance(value, base_types)
def _prepare_metadata(self, metadata: Optional[dict]) -> Any:
try:
return copy.deepcopy(metadata) # Avoid modifying the original metadata
except (TypeError, copy.Error) as e:
verbose_logger.warning(f"Langfuse Layer Error - {e}")
if metadata is None:
return None
# Filter out function types from the metadata
sanitized_metadata = {k: v for k, v in metadata.items() if not callable(v)}
return copy.deepcopy(sanitized_metadata)
except Exception as e:
verbose_logger.debug(f"Langfuse Layer Error - {e}, metadata: {metadata}")
new_metadata: Dict[str, Any] = {}
# if metadata is not a MutableMapping, return an empty dict since we can't call items() on it
if not isinstance(metadata, MutableMapping):
verbose_logger.warning(
verbose_logger.debug(
"Langfuse Layer Logging - metadata is not a MutableMapping, returning empty dict"
)
return new_metadata
@ -373,25 +385,40 @@ class LangFuseLogger:
for key, value in metadata.items():
try:
if isinstance(value, MutableMapping):
new_metadata[key] = self._prepare_metadata(value)
elif isinstance(value, (MutableSequence, MutableSet)):
new_metadata[key] = type(value)(
*(
new_metadata[key] = self._prepare_metadata(cast(dict, value))
elif isinstance(value, MutableSequence):
# For lists or other mutable sequences
new_metadata[key] = list(
(
self._prepare_metadata(v)
self._prepare_metadata(cast(dict, v))
if isinstance(v, MutableMapping)
else copy.deepcopy(v)
)
for v in value
)
elif isinstance(value, MutableSet):
# For sets specifically, create a new set by passing an iterable
new_metadata[key] = set(
(
self._prepare_metadata(cast(dict, v))
if isinstance(v, MutableMapping)
else copy.deepcopy(v)
)
for v in value
)
elif isinstance(value, BaseModel):
new_metadata[key] = value.model_dump()
elif self.is_base_type(value):
new_metadata[key] = value
else:
new_metadata[key] = copy.deepcopy(value)
verbose_logger.debug(
f"Langfuse Layer Error - Unsupported metadata type: {type(value)} for key: {key}"
)
continue
except (TypeError, copy.Error):
verbose_logger.warning(
f"Langfuse Layer Error - Couldn't copy metadata key: {key} - {traceback.format_exc()}"
verbose_logger.debug(
f"Langfuse Layer Error - Couldn't copy metadata key: {key}, type of key: {type(key)}, type of value: {type(value)} - {traceback.format_exc()}"
)
return new_metadata
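
For illustration, a rough sketch of the sanitization behavior (assuming langfuse_logger is an already-initialized LangFuseLogger; the expected result mirrors the parametrized unit test further below):

import threading

metadata = {
    "user_id": "abc-123",            # base type: kept
    "tags": ["prod", "batch"],       # base type: kept
    "span_lock": threading.Lock(),   # non-copyable: dropped
    "on_done": lambda r: r,          # callable: dropped
}

cleaned = langfuse_logger._prepare_metadata(metadata)
# cleaned == {"user_id": "abc-123", "tags": ["prod", "batch"]}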

View file

@ -2774,11 +2774,6 @@ def get_standard_logging_object_payload(
metadata=metadata
)
if litellm.cache is not None:
cache_key = litellm.cache.get_cache_key(**kwargs)
else:
cache_key = None
saved_cache_cost: float = 0.0
if cache_hit is True:
@ -2820,7 +2815,7 @@ def get_standard_logging_object_payload(
completionStartTime=completion_start_time_float,
model=kwargs.get("model", "") or "",
metadata=clean_metadata,
cache_key=cache_key,
cache_key=clean_hidden_params["cache_key"],
response_cost=response_cost,
total_tokens=usage.total_tokens,
prompt_tokens=usage.prompt_tokens,

View file

@ -1,12 +1,80 @@
model_list:
- model_name: "*"
litellm_params:
model: "*"
model: claude-3-5-sonnet-20240620
api_key: os.environ/ANTHROPIC_API_KEY
- model_name: claude-3-5-sonnet-aihubmix
litellm_params:
model: openai/claude-3-5-sonnet-20240620
input_cost_per_token: 0.000003 # 3$/M
output_cost_per_token: 0.000015 # 15$/M
api_base: "https://exampleopenaiendpoint-production.up.railway.app"
api_key: my-fake-key
- model_name: fake-openai-endpoint-2
litellm_params:
model: openai/my-fake-model
api_key: my-fake-key
api_base: https://exampleopenaiendpoint-production.up.railway.app/
stream_timeout: 0.001
timeout: 1
rpm: 1
- model_name: fake-openai-endpoint
litellm_params:
model: openai/my-fake-model
api_key: my-fake-key
api_base: https://exampleopenaiendpoint-production.up.railway.app/
## bedrock chat completions
- model_name: "*anthropic.claude*"
litellm_params:
model: bedrock/*anthropic.claude*
aws_access_key_id: os.environ/BEDROCK_AWS_ACCESS_KEY_ID
aws_secret_access_key: os.environ/BEDROCK_AWS_SECRET_ACCESS_KEY
aws_region_name: os.environ/AWS_REGION_NAME
guardrailConfig:
"guardrailIdentifier": "h4dsqwhp6j66"
"guardrailVersion": "2"
"trace": "enabled"
## bedrock embeddings
- model_name: "*amazon.titan-embed-*"
litellm_params:
model: bedrock/amazon.titan-embed-*
aws_access_key_id: os.environ/BEDROCK_AWS_ACCESS_KEY_ID
aws_secret_access_key: os.environ/BEDROCK_AWS_SECRET_ACCESS_KEY
aws_region_name: os.environ/AWS_REGION_NAME
- model_name: "*cohere.embed-*"
litellm_params:
model: bedrock/cohere.embed-*
aws_access_key_id: os.environ/BEDROCK_AWS_ACCESS_KEY_ID
aws_secret_access_key: os.environ/BEDROCK_AWS_SECRET_ACCESS_KEY
aws_region_name: os.environ/AWS_REGION_NAME
- model_name: "bedrock/*"
litellm_params:
model: bedrock/*
aws_access_key_id: os.environ/BEDROCK_AWS_ACCESS_KEY_ID
aws_secret_access_key: os.environ/BEDROCK_AWS_SECRET_ACCESS_KEY
aws_region_name: os.environ/AWS_REGION_NAME
- model_name: gpt-4
litellm_params:
model: azure/chatgpt-v-2
api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
api_version: "2023-05-15"
api_key: os.environ/AZURE_API_KEY # The `os.environ/` prefix tells litellm to read this from the env. See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault
rpm: 480
timeout: 300
stream_timeout: 60
litellm_settings:
fallbacks: [{ "claude-3-5-sonnet-20240620": ["claude-3-5-sonnet-aihubmix"] }]
callbacks: ["otel", "prometheus"]
default_redis_batch_cache_expiry: 10
# default_team_settings:
# - team_id: "dbe2f686-a686-4896-864a-4c3924458709"
# success_callback: ["langfuse"]
# langfuse_public_key: os.environ/LANGFUSE_PUB_KEY_1 # Project 1
# langfuse_secret: os.environ/LANGFUSE_PRIVATE_KEY_1 # Project 1
# litellm_settings:
# cache: True

View file

@ -1308,7 +1308,7 @@ async def update_cache( # noqa: PLR0915
await _update_team_cache()
asyncio.create_task(
user_api_key_cache.async_batch_set_cache(
user_api_key_cache.async_set_cache_pipeline(
cache_list=values_to_update_in_cache,
ttl=60,
litellm_parent_otel_span=parent_otel_span,
@ -2978,7 +2978,7 @@ class ProxyStartupEvent:
if (
proxy_logging_obj is not None
and proxy_logging_obj.slack_alerting_instance is not None
and proxy_logging_obj.slack_alerting_instance.alerting is not None
and prisma_client is not None
):
print("Alerting: Initializing Weekly/Monthly Spend Reports") # noqa

View file

@ -175,7 +175,7 @@ class InternalUsageCache:
local_only: bool = False,
**kwargs,
) -> None:
return await self.dual_cache.async_batch_set_cache(
return await self.dual_cache.async_set_cache_pipeline(
cache_list=cache_list,
local_only=local_only,
litellm_parent_otel_span=litellm_parent_otel_span,

View file

@ -339,11 +339,7 @@ class Router:
cache_config: Dict[str, Any] = {}
self.client_ttl = client_ttl
if redis_url is not None or (
redis_host is not None
and redis_port is not None
and redis_password is not None
):
if redis_url is not None or (redis_host is not None and redis_port is not None):
cache_type = "redis"
if redis_url is not None:
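
With the password requirement dropped, pointing the Router cache at an unauthenticated Redis only needs a host and port (or a redis_url). A minimal sketch under that assumption (the model entry is illustrative):

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {"model": "gpt-3.5-turbo"},
        }
    ],
    cache_responses=True,
    redis_host="localhost",  # assumption: local Redis without AUTH enabled
    redis_port=6379,         # no redis_password required anymore
)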

View file

@ -796,7 +796,7 @@ def client(original_function): # noqa: PLR0915
and kwargs.get("_arealtime", False) is not True
): # allow users to control returning cached responses from the completion function
# checking cache
print_verbose("INSIDE CHECKING CACHE")
verbose_logger.debug("INSIDE CHECKING SYNC CACHE")
caching_handler_response: CachingHandlerResponse = (
_llm_caching_handler._sync_get_cache(
model=model or "",
@ -808,6 +808,7 @@ def client(original_function): # noqa: PLR0915
args=args,
)
)
if caching_handler_response.cached_result is not None:
return caching_handler_response.cached_result

View file

@ -0,0 +1,223 @@
from abc import ABC, abstractmethod
from litellm.caching import LiteLLMCacheType
import os
import sys
import time
import traceback
import uuid
from dotenv import load_dotenv
from test_rerank import assert_response_shape
load_dotenv()
import os
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import asyncio
import hashlib
import random
import pytest
import litellm
from litellm.caching import Cache
from litellm import completion, embedding
class LLMCachingUnitTests(ABC):
@abstractmethod
def get_cache_type(self) -> LiteLLMCacheType:
pass
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_cache_completion(self, sync_mode):
litellm._turn_on_debug()
random_number = random.randint(
1, 100000
) # add a random number to ensure it's always adding / reading from cache
messages = [
{
"role": "user",
"content": f"write a one sentence poem about: {random_number}",
}
]
cache_type = self.get_cache_type()
litellm.cache = Cache(
type=cache_type,
)
if sync_mode:
response1 = completion(
"gpt-3.5-turbo",
messages=messages,
caching=True,
max_tokens=20,
mock_response="This number is so great!",
)
else:
response1 = await litellm.acompletion(
"gpt-3.5-turbo",
messages=messages,
caching=True,
max_tokens=20,
mock_response="This number is so great!",
)
# response2 is mocked to a different response from response1,
# but the completion from the cache should be used instead of the mock
# response since the input is the same as response1
await asyncio.sleep(0.5)
if sync_mode:
response2 = completion(
"gpt-3.5-turbo",
messages=messages,
caching=True,
max_tokens=20,
mock_response="This number is great!",
)
else:
response2 = await litellm.acompletion(
"gpt-3.5-turbo",
messages=messages,
caching=True,
max_tokens=20,
mock_response="This number is great!",
)
if (
response1["choices"][0]["message"]["content"]
!= response2["choices"][0]["message"]["content"]
): # 1 and 2 should be the same
# 1&2 have the exact same input params. This MUST Be a CACHE HIT
print(f"response1: {response1}")
print(f"response2: {response2}")
pytest.fail(
f"Error occurred: response1 - {response1['choices'][0]['message']['content']} != response2 - {response2['choices'][0]['message']['content']}"
)
# Since the parameters are not the same as response1, response3 should actually
# be the mock response
if sync_mode:
response3 = completion(
"gpt-3.5-turbo",
messages=messages,
caching=True,
temperature=0.5,
mock_response="This number is awful!",
)
else:
response3 = await litellm.acompletion(
"gpt-3.5-turbo",
messages=messages,
caching=True,
temperature=0.5,
mock_response="This number is awful!",
)
print("\nresponse 1", response1)
print("\nresponse 2", response2)
print("\nresponse 3", response3)
# print("\nresponse 4", response4)
litellm.cache = None
litellm.success_callback = []
litellm._async_success_callback = []
# 1 & 2 should be exactly the same
# 1 & 3 should be different, since input params are diff
if (
response1["choices"][0]["message"]["content"]
== response3["choices"][0]["message"]["content"]
):
# if input params like max_tokens, temperature are diff it should NOT be a cache hit
print(f"response1: {response1}")
print(f"response3: {response3}")
pytest.fail(
"Response 1 == response 3. Same model, diff params should not be a cache hit."
)
assert response1.id == response2.id
assert response1.created == response2.created
assert (
response1.choices[0].message.content == response2.choices[0].message.content
)
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_cache_embedding(self, sync_mode):
litellm._turn_on_debug()
random_number = random.randint(
1, 100000
) # add a random number to ensure it's always adding / reading from cache
input = [f"hello {random_number}"]
cache_type = self.get_cache_type()
litellm.cache = Cache(
type=cache_type,
)
if sync_mode:
response1 = embedding(
"openai/text-embedding-ada-002",
input=input,
caching=True,
)
else:
response1 = await litellm.aembedding(
"openai/text-embedding-ada-002",
input=input,
caching=True,
)
# response2 should be served from the cache,
# since the input is the same as response1
await asyncio.sleep(0.5)
if sync_mode:
response2 = embedding(
"openai/text-embedding-ada-002",
input=input,
caching=True,
)
else:
response2 = await litellm.aembedding(
"openai/text-embedding-ada-002",
input=input,
caching=True,
)
if response2._hidden_params["cache_hit"] is not True:
pytest.fail("Cache hit should be True")
# Since the parameters are not the same as response1,
# response3 should not be a cache hit
if sync_mode:
response3 = embedding(
"openai/text-embedding-ada-002",
input=input,
user="charlie",
caching=True,
)
else:
response3 = await litellm.aembedding(
"openai/text-embedding-ada-002",
input=input,
caching=True,
user="charlie",
)
print("\nresponse 1", response1)
print("\nresponse 2", response2)
print("\nresponse 3", response3)
# print("\nresponse 4", response4)
litellm.cache = None
litellm.success_callback = []
litellm._async_success_callback = []
# 1 & 2 should be exactly the same
# 1 & 3 should be different, since input params are diff
if response3._hidden_params.get("cache_hit") is True:
pytest.fail("Cache hit should not be True")

View file

@ -438,7 +438,7 @@ async def test_send_daily_reports_ignores_zero_values():
slack_alerting.internal_usage_cache.async_batch_get_cache = AsyncMock(
return_value=[None, 0, 10, 0, 0, None]
)
slack_alerting.internal_usage_cache.async_batch_set_cache = AsyncMock()
slack_alerting.internal_usage_cache.async_set_cache_pipeline = AsyncMock()
router.get_model_info.side_effect = lambda x: {"litellm_params": {"model": x}}

View file

@ -1103,81 +1103,6 @@ async def test_redis_cache_acompletion_stream_bedrock():
raise e
def test_disk_cache_completion():
litellm.set_verbose = False
random_number = random.randint(
1, 100000
) # add a random number to ensure it's always adding / reading from cache
messages = [
{"role": "user", "content": f"write a one sentence poem about: {random_number}"}
]
litellm.cache = Cache(
type="disk",
)
response1 = completion(
model="gpt-3.5-turbo",
messages=messages,
caching=True,
max_tokens=20,
mock_response="This number is so great!",
)
# response2 is mocked to a different response from response1,
# but the completion from the cache should be used instead of the mock
# response since the input is the same as response1
response2 = completion(
model="gpt-3.5-turbo",
messages=messages,
caching=True,
max_tokens=20,
mock_response="This number is awful!",
)
# Since the parameters are not the same as response1, response3 should actually
# be the mock response
response3 = completion(
model="gpt-3.5-turbo",
messages=messages,
caching=True,
temperature=0.5,
mock_response="This number is awful!",
)
print("\nresponse 1", response1)
print("\nresponse 2", response2)
print("\nresponse 3", response3)
# print("\nresponse 4", response4)
litellm.cache = None
litellm.success_callback = []
litellm._async_success_callback = []
# 1 & 2 should be exactly the same
# 1 & 3 should be different, since input params are diff
if (
response1["choices"][0]["message"]["content"]
!= response2["choices"][0]["message"]["content"]
): # 1 and 2 should be the same
# 1&2 have the exact same input params. This MUST Be a CACHE HIT
print(f"response1: {response1}")
print(f"response2: {response2}")
pytest.fail(f"Error occurred:")
if (
response1["choices"][0]["message"]["content"]
== response3["choices"][0]["message"]["content"]
):
# if input params like max_tokens, temperature are diff it should NOT be a cache hit
print(f"response1: {response1}")
print(f"response3: {response3}")
pytest.fail(
f"Response 1 == response 3. Same model, diff params shoudl not cache Error"
f" occurred:"
)
assert response1.id == response2.id
assert response1.created == response2.created
assert response1.choices[0].message.content == response2.choices[0].message.content
# @pytest.mark.skip(reason="AWS Suspended Account")
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio

View file

@ -0,0 +1,11 @@
from cache_unit_tests import LLMCachingUnitTests
from litellm.caching import LiteLLMCacheType
class TestDiskCacheUnitTests(LLMCachingUnitTests):
def get_cache_type(self) -> LiteLLMCacheType:
return LiteLLMCacheType.DISK
# if __name__ == "__main__":
# pytest.main([__file__, "-v", "-s"])
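
The shared suite is meant to be subclassed once per backend; a hypothetical example (not part of this commit) of wiring up another cache type:

from cache_unit_tests import LLMCachingUnitTests
from litellm.caching import LiteLLMCacheType


class TestRedisCacheUnitTests(LLMCachingUnitTests):
    def get_cache_type(self) -> LiteLLMCacheType:
        return LiteLLMCacheType.REDIS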

View file

@ -146,7 +146,7 @@ async def test_dual_cache_batch_operations(is_async):
# Set values
if is_async:
await dual_cache.async_batch_set_cache(cache_list)
await dual_cache.async_set_cache_pipeline(cache_list)
else:
for key, value in cache_list:
dual_cache.set_cache(key, value)

View file

@ -212,26 +212,48 @@ def test_get_langfuse_logger_for_request_with_cached_logger():
assert result == cached_logger
mock_cache.get_cache.assert_called_once()
@pytest.mark.parametrize("metadata", [
{'a': 1, 'b': 2, 'c': 3},
{'a': {'nested_a': 1}, 'b': {'nested_b': 2}},
{'a': [1, 2, 3], 'b': {4, 5, 6}},
{'a': (1, 2), 'b': frozenset([3, 4]), 'c': {'d': [5, 6]}},
{'lock': threading.Lock()},
{'func': lambda x: x + 1},
@pytest.mark.parametrize(
"metadata, expected_metadata",
[
({"a": 1, "b": 2, "c": 3}, {"a": 1, "b": 2, "c": 3}),
(
{"a": {"nested_a": 1}, "b": {"nested_b": 2}},
{"a": {"nested_a": 1}, "b": {"nested_b": 2}},
),
({"a": [1, 2, 3], "b": {4, 5, 6}}, {"a": [1, 2, 3], "b": {4, 5, 6}}),
(
{"a": (1, 2), "b": frozenset([3, 4]), "c": {"d": [5, 6]}},
{"a": (1, 2), "b": frozenset([3, 4]), "c": {"d": [5, 6]}},
),
({"lock": threading.Lock()}, {}),
({"func": lambda x: x + 1}, {}),
(
{
'int': 42,
'str': 'hello',
'list': [1, 2, 3],
'set': {4, 5},
'dict': {'nested': 'value'},
'non_copyable': threading.Lock(),
'function': print
"int": 42,
"str": "hello",
"list": [1, 2, 3],
"set": {4, 5},
"dict": {"nested": "value"},
"non_copyable": threading.Lock(),
"function": print,
},
['list', 'not', 'a', 'dict'],
{'timestamp': datetime.now()},
{},
None,
])
def test_langfuse_logger_prepare_metadata(metadata):
global_langfuse_logger._prepare_metadata(metadata)
{
"int": 42,
"str": "hello",
"list": [1, 2, 3],
"set": {4, 5},
"dict": {"nested": "value"},
},
),
(
{"list": ["list", "not", "a", "dict"]},
{"list": ["list", "not", "a", "dict"]},
),
({}, {}),
(None, None),
],
)
def test_langfuse_logger_prepare_metadata(metadata, expected_metadata):
result = global_langfuse_logger._prepare_metadata(metadata)
assert result == expected_metadata