(refactor) get_cache_key to be under 100 LOC function (#6327)

* refactor - use helpers for name space and hashing
* use openai to get the relevant supported params
* use helpers for getting cache key
* fix test caching
* use get/set helpers for preset cache keys
* make get_cache_key under 100 LOC
* fix _get_model_param_value
* fix _get_caching_group
* fix linting error
* add unit testing for get cache key
* test_generate_streaming_content

Parent: 4cbdad9fc5
Commit: 979e8ea526
5 changed files with 477 additions and 124 deletions
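
At a glance, the refactored flow looks like the sketch below. It is illustrative only: the model, messages, and namespace values are made up, while the helper names are the ones introduced in this commit.

# Illustrative sketch of the refactored get_cache_key() flow. The values below
# (model, messages, namespace) are made up; the helper names are from this commit.
from litellm.caching.caching import Cache

cache = Cache(namespace="my-app")  # hypothetical namespace

# get_cache_key() now delegates to small helpers:
#   _get_preset_cache_key_from_kwargs        -> reuse a previously computed key
#   _get_relevant_args_to_use_for_cache_key  -> OpenAI-spec'd params to include
#   _get_hashed_cache_key                    -> sha256 of the "param: value" string
#   _add_redis_namespace_to_cache_key        -> prefix with the redis namespace, if any
key = cache.get_cache_key(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hello"}],
    temperature=0.2,
)
print(key)  # e.g. "my-app:<64-character sha256 hex digest>"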
@@ -17,13 +17,26 @@ import logging
 import time
 import traceback
 from enum import Enum
-from typing import Any, List, Literal, Optional, Tuple, Union
+from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Union

+from openai.types.audio.transcription_create_params import TranscriptionCreateParams
+from openai.types.chat.completion_create_params import (
+    CompletionCreateParamsNonStreaming,
+    CompletionCreateParamsStreaming,
+)
+from openai.types.completion_create_params import (
+    CompletionCreateParamsNonStreaming as TextCompletionCreateParamsNonStreaming,
+)
+from openai.types.completion_create_params import (
+    CompletionCreateParamsStreaming as TextCompletionCreateParamsStreaming,
+)
+from openai.types.embedding_create_params import EmbeddingCreateParams
 from pydantic import BaseModel

 import litellm
 from litellm._logging import verbose_logger
 from litellm.types.caching import *
+from litellm.types.rerank import RerankRequest
 from litellm.types.utils import all_litellm_params

 from .base_cache import BaseCache
@@ -220,7 +233,7 @@ class Cache:
         if self.namespace is not None and isinstance(self.cache, RedisCache):
             self.cache.namespace = self.namespace

-    def get_cache_key(self, *args, **kwargs) -> str:  # noqa: PLR0915
+    def get_cache_key(self, *args, **kwargs) -> str:
         """
         Get the cache key for the given arguments.

@@ -232,106 +245,19 @@ class Cache:
             str: The cache key generated from the arguments, or None if no cache key could be generated.
         """
         cache_key = ""
-        print_verbose(f"\nGetting Cache key. Kwargs: {kwargs}")
+        verbose_logger.debug("\nGetting Cache key. Kwargs: %s", kwargs)

-        # for streaming, we use preset_cache_key. It's created in wrapper(), we do this because optional params like max_tokens, get transformed for bedrock -> max_new_tokens
-        if kwargs.get("litellm_params", {}).get("preset_cache_key", None) is not None:
-            _preset_cache_key = kwargs.get("litellm_params", {}).get(
-                "preset_cache_key", None
-            )
-            print_verbose(f"\nReturning preset cache key: {_preset_cache_key}")
-            return _preset_cache_key
+        preset_cache_key = self._get_preset_cache_key_from_kwargs(**kwargs)
+        if preset_cache_key is not None:
+            verbose_logger.debug("\nReturning preset cache key: %s", preset_cache_key)
+            return preset_cache_key

-        # sort kwargs by keys, since model: [gpt-4, temperature: 0.2, max_tokens: 200] == [temperature: 0.2, max_tokens: 200, model: gpt-4]
-        completion_kwargs = [
-            "model",
-            "messages",
-            "prompt",
-            "temperature",
-            "top_p",
-            "n",
-            "stop",
-            "max_tokens",
-            "presence_penalty",
-            "frequency_penalty",
-            "logit_bias",
-            "user",
-            "response_format",
-            "seed",
-            "tools",
-            "tool_choice",
-            "stream",
-        ]
-        embedding_only_kwargs = [
-            "input",
-            "encoding_format",
-        ]  # embedding kwargs = model, input, user, encoding_format. Model, user are checked in completion_kwargs
-        transcription_only_kwargs = [
-            "file",
-            "language",
-        ]
-        rerank_only_kwargs = [
-            "top_n",
-            "rank_fields",
-            "return_documents",
-            "max_chunks_per_doc",
-            "documents",
-            "query",
-        ]
-        # combined_kwargs - NEEDS to be ordered across get_cache_key(). Do not use a set()
-        combined_kwargs = (
-            completion_kwargs
-            + embedding_only_kwargs
-            + transcription_only_kwargs
-            + rerank_only_kwargs
-        )
+        combined_kwargs = self._get_relevant_args_to_use_for_cache_key()
         litellm_param_kwargs = all_litellm_params
         for param in kwargs:
             if param in combined_kwargs:
-                # check if param == model and model_group is passed in, then override model with model_group
-                if param == "model":
-                    model_group = None
-                    caching_group = None
-                    metadata = kwargs.get("metadata", None)
-                    litellm_params = kwargs.get("litellm_params", {})
-                    if metadata is not None:
-                        model_group = metadata.get("model_group")
-                        model_group = metadata.get("model_group", None)
-                        caching_groups = metadata.get("caching_groups", None)
-                        if caching_groups:
-                            for group in caching_groups:
-                                if model_group in group:
-                                    caching_group = group
-                                    break
-                    if litellm_params is not None:
-                        metadata = litellm_params.get("metadata", None)
-                        if metadata is not None:
-                            model_group = metadata.get("model_group", None)
-                            caching_groups = metadata.get("caching_groups", None)
-                            if caching_groups:
-                                for group in caching_groups:
-                                    if model_group in group:
-                                        caching_group = group
-                                        break
-                    param_value = (
-                        caching_group or model_group or kwargs[param]
-                    )  # use caching_group, if set then model_group if it exists, else use kwargs["model"]
-                elif param == "file":
-                    file = kwargs.get("file")
-                    metadata = kwargs.get("metadata", {})
-                    litellm_params = kwargs.get("litellm_params", {})
-
-                    # get checksum of file content
-                    param_value = (
-                        metadata.get("file_checksum")
-                        or getattr(file, "name", None)
-                        or metadata.get("file_name")
-                        or litellm_params.get("file_name")
-                    )
-                else:
-                    if kwargs[param] is None:
-                        continue  # ignore None params
-                    param_value = kwargs[param]
-                cache_key += f"{str(param)}: {str(param_value)}"
+                param_value: Optional[str] = self._get_param_value(param, kwargs)
+                if param_value is not None:
+                    cache_key += f"{str(param)}: {str(param_value)}"
             elif (
                 param not in litellm_param_kwargs
@@ -344,19 +270,200 @@ class Cache:
                     param_value = kwargs[param]
                     cache_key += f"{str(param)}: {str(param_value)}"

-        print_verbose(f"\nCreated cache key: {cache_key}")
-        # Use hashlib to create a sha256 hash of the cache key
+        verbose_logger.debug("\nCreated cache key: %s", cache_key)
+        hashed_cache_key = self._get_hashed_cache_key(cache_key)
+        hashed_cache_key = self._add_redis_namespace_to_cache_key(
+            hashed_cache_key, **kwargs
+        )
+        self._set_preset_cache_key_in_kwargs(
+            preset_cache_key=hashed_cache_key, **kwargs
+        )
+        return hashed_cache_key
+
+    def _get_param_value(
+        self,
+        param: str,
+        kwargs: dict,
+    ) -> Optional[str]:
+        """
+        Get the value for the given param from kwargs
+        """
+        if param == "model":
+            return self._get_model_param_value(kwargs)
+        elif param == "file":
+            return self._get_file_param_value(kwargs)
+        return kwargs[param]
+
+    def _get_model_param_value(self, kwargs: dict) -> str:
+        """
+        Handles getting the value for the 'model' param from kwargs
+
+        1. If caching groups are set, then return the caching group as the model https://docs.litellm.ai/docs/routing#caching-across-model-groups
+        2. Else if a model_group is set, then return the model_group as the model. This is used for all requests sent through the litellm.Router()
+        3. Else use the `model` passed in kwargs
+        """
+        metadata: Dict = kwargs.get("metadata", {}) or {}
+        litellm_params: Dict = kwargs.get("litellm_params", {}) or {}
+        metadata_in_litellm_params: Dict = litellm_params.get("metadata", {}) or {}
+        model_group: Optional[str] = metadata.get(
+            "model_group"
+        ) or metadata_in_litellm_params.get("model_group")
+        caching_group = self._get_caching_group(metadata, model_group)
+        return caching_group or model_group or kwargs["model"]
+
+    def _get_caching_group(
+        self, metadata: dict, model_group: Optional[str]
+    ) -> Optional[str]:
+        caching_groups: Optional[List] = metadata.get("caching_groups", [])
+        if caching_groups:
+            for group in caching_groups:
+                if model_group in group:
+                    return str(group)
+        return None
+
+    def _get_file_param_value(self, kwargs: dict) -> str:
+        """
+        Handles getting the value for the 'file' param from kwargs. Used for `transcription` requests
+        """
+        file = kwargs.get("file")
+        metadata = kwargs.get("metadata", {})
+        litellm_params = kwargs.get("litellm_params", {})
+        return (
+            metadata.get("file_checksum")
+            or getattr(file, "name", None)
+            or metadata.get("file_name")
+            or litellm_params.get("file_name")
+        )
+
+    def _get_preset_cache_key_from_kwargs(self, **kwargs) -> Optional[str]:
+        """
+        Get the preset cache key from kwargs["litellm_params"]
+
+        We use _get_preset_cache_keys for two reasons
+
+        1. optional params like max_tokens, get transformed for bedrock -> max_new_tokens
+        2. avoid doing duplicate / repeated work
+        """
+        if kwargs:
+            if "litellm_params" in kwargs:
+                return kwargs["litellm_params"].get("preset_cache_key", None)
+        return None
+
+    def _set_preset_cache_key_in_kwargs(self, preset_cache_key: str, **kwargs) -> None:
+        """
+        Set the calculated cache key in kwargs
+
+        This is used to avoid doing duplicate / repeated work
+
+        Placed in kwargs["litellm_params"]
+        """
+        if kwargs:
+            if "litellm_params" in kwargs:
+                kwargs["litellm_params"]["preset_cache_key"] = preset_cache_key
+
+    def _get_relevant_args_to_use_for_cache_key(self) -> Set[str]:
+        """
+        Gets the supported kwargs for each call type and combines them
+        """
+        chat_completion_kwargs = self._get_litellm_supported_chat_completion_kwargs()
+        text_completion_kwargs = self._get_litellm_supported_text_completion_kwargs()
+        embedding_kwargs = self._get_litellm_supported_embedding_kwargs()
+        transcription_kwargs = self._get_litellm_supported_transcription_kwargs()
+        rerank_kwargs = self._get_litellm_supported_rerank_kwargs()
+        exclude_kwargs = self._get_kwargs_to_exclude_from_cache_key()
+
+        combined_kwargs = chat_completion_kwargs.union(
+            text_completion_kwargs,
+            embedding_kwargs,
+            transcription_kwargs,
+            rerank_kwargs,
+        )
+        combined_kwargs = combined_kwargs.difference(exclude_kwargs)
+        return combined_kwargs
+
+    def _get_litellm_supported_chat_completion_kwargs(self) -> Set[str]:
+        """
+        Get the litellm supported chat completion kwargs
+
+        This follows the OpenAI API Spec
+        """
+        all_chat_completion_kwargs = set(
+            CompletionCreateParamsNonStreaming.__annotations__.keys()
+        ).union(set(CompletionCreateParamsStreaming.__annotations__.keys()))
+        return all_chat_completion_kwargs
+
+    def _get_litellm_supported_text_completion_kwargs(self) -> Set[str]:
+        """
+        Get the litellm supported text completion kwargs
+
+        This follows the OpenAI API Spec
+        """
+        all_text_completion_kwargs = set(
+            TextCompletionCreateParamsNonStreaming.__annotations__.keys()
+        ).union(set(TextCompletionCreateParamsStreaming.__annotations__.keys()))
+        return all_text_completion_kwargs
+
+    def _get_litellm_supported_rerank_kwargs(self) -> Set[str]:
+        """
+        Get the litellm supported rerank kwargs
+        """
+        return set(RerankRequest.model_fields.keys())
+
+    def _get_litellm_supported_embedding_kwargs(self) -> Set[str]:
+        """
+        Get the litellm supported embedding kwargs
+
+        This follows the OpenAI API Spec
+        """
+        return set(EmbeddingCreateParams.__annotations__.keys())
+
+    def _get_litellm_supported_transcription_kwargs(self) -> Set[str]:
+        """
+        Get the litellm supported transcription kwargs
+
+        This follows the OpenAI API Spec
+        """
+        return set(TranscriptionCreateParams.__annotations__.keys())
+
+    def _get_kwargs_to_exclude_from_cache_key(self) -> Set[str]:
+        """
+        Get the kwargs to exclude from the cache key
+        """
+        return set(["metadata"])
+
+    def _get_hashed_cache_key(self, cache_key: str) -> str:
+        """
+        Get the hashed cache key for the given cache key.
+
+        Use hashlib to create a sha256 hash of the cache key
+
+        Args:
+            cache_key (str): The cache key to hash.
+
+        Returns:
+            str: The hashed cache key.
+        """
         hash_object = hashlib.sha256(cache_key.encode())
         # Hexadecimal representation of the hash
         hash_hex = hash_object.hexdigest()
-        print_verbose(f"Hashed cache key (SHA-256): {hash_hex}")
-        if kwargs.get("metadata", {}).get("redis_namespace", None) is not None:
-            _namespace = kwargs.get("metadata", {}).get("redis_namespace", None)
-            hash_hex = f"{_namespace}:{hash_hex}"
-            print_verbose(f"Hashed Key with Namespace: {hash_hex}")
-        elif self.namespace is not None:
-            hash_hex = f"{self.namespace}:{hash_hex}"
-            print_verbose(f"Hashed Key with Namespace: {hash_hex}")
+        verbose_logger.debug("Hashed cache key (SHA-256): %s", hash_hex)
+        return hash_hex
+
+    def _add_redis_namespace_to_cache_key(self, hash_hex: str, **kwargs) -> str:
+        """
+        If a redis namespace is provided, add it to the cache key
+
+        Args:
+            hash_hex (str): The hashed cache key.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            str: The final hashed cache key with the redis namespace.
+        """
+        namespace = kwargs.get("metadata", {}).get("redis_namespace") or self.namespace
+        if namespace:
+            hash_hex = f"{namespace}:{hash_hex}"
+        verbose_logger.debug("Final hashed key: %s", hash_hex)
         return hash_hex

     def generate_streaming_content(self, content):
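
The supported-params helpers above derive the cacheable kwargs from the OpenAI SDK's request TypedDicts instead of a hard-coded list. A standalone sketch of that technique follows, assuming the openai package is installed; the exclusion set mirrors _get_kwargs_to_exclude_from_cache_key:

# Standalone sketch of the technique used by the new supported-params helpers:
# OpenAI's request types are TypedDicts, so __annotations__ lists every accepted
# request field without hard-coding it.
from openai.types.chat.completion_create_params import (
    CompletionCreateParamsNonStreaming,
    CompletionCreateParamsStreaming,
)
from openai.types.embedding_create_params import EmbeddingCreateParams

chat_params = set(CompletionCreateParamsNonStreaming.__annotations__.keys()) | set(
    CompletionCreateParamsStreaming.__annotations__.keys()
)
embedding_params = set(EmbeddingCreateParams.__annotations__.keys())

# litellm-internal fields (here just "metadata") are excluded so the cache key
# reflects only request content.
relevant = (chat_params | embedding_params) - {"metadata"}
print(sorted(relevant))  # includes e.g. "messages", "temperature", "input", ...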
@@ -182,7 +182,9 @@ class LLMCachingHandler:
                 end_time=end_time,
                 cache_hit=cache_hit,
             )
-            cache_key = kwargs.get("preset_cache_key", None)
+            cache_key = litellm.cache._get_preset_cache_key_from_kwargs(
+                **kwargs
+            )
             if (
                 isinstance(cached_result, BaseModel)
                 or isinstance(cached_result, CustomStreamWrapper)
@@ -236,12 +238,7 @@ class LLMCachingHandler:
             original_function=original_function
         ):
             print_verbose("Checking Cache")
-            preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs)
-            kwargs["preset_cache_key"] = (
-                preset_cache_key  # for streaming calls, we need to pass the preset_cache_key
-            )
             cached_result = litellm.cache.get_cache(*args, **kwargs)
-
             if cached_result is not None:
                 if "detail" in cached_result:
                     # implies an error occurred
@@ -285,7 +282,9 @@ class LLMCachingHandler:
                     target=logging_obj.success_handler,
                     args=(cached_result, start_time, end_time, cache_hit),
                 ).start()
-                cache_key = kwargs.get("preset_cache_key", None)
+                cache_key = litellm.cache._get_preset_cache_key_from_kwargs(
+                    **kwargs
+                )
                 if (
                     isinstance(cached_result, BaseModel)
                     or isinstance(cached_result, CustomStreamWrapper)
@@ -493,10 +492,6 @@ class LLMCachingHandler:
                 if all(result is None for result in cached_result):
                     cached_result = None
             else:
-                preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs)
-                kwargs["preset_cache_key"] = (
-                    preset_cache_key  # for streaming calls, we need to pass the preset_cache_key
-                )
                 if litellm.cache._supports_async() is True:
                     cached_result = await litellm.cache.async_get_cache(*args, **kwargs)
                 else:  # for s3 caching. [NOT RECOMMENDED IN PROD - this will slow down responses since boto3 is sync]
@@ -842,10 +837,16 @@ class LLMCachingHandler:
             "metadata": kwargs.get("metadata", {}),
             "model_info": kwargs.get("model_info", {}),
             "proxy_server_request": kwargs.get("proxy_server_request", None),
-            "preset_cache_key": kwargs.get("preset_cache_key", None),
             "stream_response": kwargs.get("stream_response", {}),
         }

+        if litellm.cache is not None:
+            litellm_params["preset_cache_key"] = (
+                litellm.cache._get_preset_cache_key_from_kwargs(**kwargs)
+            )
+        else:
+            litellm_params["preset_cache_key"] = None
+
         logging_obj.update_environment_variables(
             model=model,
             user=kwargs.get("user", None),
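
All of the handler changes above read the key back through _get_preset_cache_key_from_kwargs rather than a loose kwargs["preset_cache_key"] entry. A minimal sketch of that round trip, with illustrative values:

# Minimal sketch of the preset-cache-key round trip the handler code above relies on;
# the request values are illustrative, the helpers are the ones added in this commit.
from litellm.caching.caching import Cache

cache = Cache()
kwargs = {
    "model": "gpt-3.5-turbo",
    "messages": [{"role": "user", "content": "hi"}],
    "litellm_params": {},  # get_cache_key() stores the computed key in here
}

first_key = cache.get_cache_key(**kwargs)

# Later lookups (e.g. logging a cache hit, or a streaming follow-up) reuse the
# stored key instead of re-hashing the request.
assert cache._get_preset_cache_key_from_kwargs(**kwargs) == first_key
assert cache.get_cache_key(**kwargs) == first_key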
@@ -664,10 +664,10 @@ class LangFuseLogger:
         if "cache_key" in litellm.langfuse_default_tags:
             _hidden_params = metadata.get("hidden_params", {}) or {}
             _cache_key = _hidden_params.get("cache_key", None)
-            if _cache_key is None:
+            if _cache_key is None and litellm.cache is not None:
                 # fallback to using "preset_cache_key"
-                _preset_cache_key = kwargs.get("litellm_params", {}).get(
-                    "preset_cache_key", None
+                _preset_cache_key = litellm.cache._get_preset_cache_key_from_kwargs(
+                    **kwargs
                 )
                 _cache_key = _preset_cache_key
             tags.append(f"cache_key:{_cache_key}")
@@ -974,7 +974,7 @@ async def test_redis_cache_acompletion_stream():
             response_1_content += chunk.choices[0].delta.content or ""
         print(response_1_content)

-        time.sleep(0.5)
+        await asyncio.sleep(0.5)
         print("\n\n Response 1 content: ", response_1_content, "\n\n")

         response2 = await litellm.acompletion(
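
The only change in this test is swapping a blocking sleep for an awaitable one: inside an async test, time.sleep() stalls the event loop, while asyncio.sleep() suspends only the current coroutine. A minimal illustration, assuming pytest-asyncio and a hypothetical test name:

import asyncio
import pytest


@pytest.mark.asyncio
async def test_sleep_without_blocking():  # hypothetical example, not from the diff
    # time.sleep(0.5) here would block the whole event loop;
    # asyncio.sleep(0.5) yields control while waiting.
    await asyncio.sleep(0.5)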
tests/local_testing/test_unit_test_caching.py (new file, 245 lines)
@@ -0,0 +1,245 @@
import os
import sys
import time
import traceback
import uuid

from dotenv import load_dotenv
from test_rerank import assert_response_shape


load_dotenv()
sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import asyncio
import hashlib
import random

import pytest

import litellm
from litellm import aembedding, completion, embedding
from litellm.caching.caching import Cache

from unittest.mock import AsyncMock, patch, MagicMock
from litellm.caching.caching_handler import LLMCachingHandler, CachingHandlerResponse
from litellm.caching.caching import LiteLLMCacheType
from litellm.types.utils import CallTypes
from litellm.types.rerank import RerankResponse
from litellm.types.utils import (
    ModelResponse,
    EmbeddingResponse,
    TextCompletionResponse,
    TranscriptionResponse,
    Embedding,
)
from datetime import timedelta, datetime
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
from litellm._logging import verbose_logger
import logging


def test_get_kwargs_for_cache_key():
    _cache = litellm.Cache()
    relevant_kwargs = _cache._get_relevant_args_to_use_for_cache_key()
    print(relevant_kwargs)


def test_get_cache_key_chat_completion():
    cache = Cache()
    kwargs = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "Hello, world!"}],
        "temperature": 0.7,
    }
    cache_key_1 = cache.get_cache_key(**kwargs)
    assert isinstance(cache_key_1, str)
    assert len(cache_key_1) > 0

    kwargs_2 = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "Hello, world!"}],
        "max_completion_tokens": 100,
    }
    cache_key_2 = cache.get_cache_key(**kwargs_2)
    assert cache_key_1 != cache_key_2

    kwargs_3 = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "Hello, world!"}],
        "max_completion_tokens": 100,
    }
    cache_key_3 = cache.get_cache_key(**kwargs_3)
    assert cache_key_2 == cache_key_3


def test_get_cache_key_embedding():
    cache = Cache()
    kwargs = {
        "model": "text-embedding-3-small",
        "input": "Hello, world!",
        "dimensions": 1536,
    }
    cache_key_1 = cache.get_cache_key(**kwargs)
    assert isinstance(cache_key_1, str)
    assert len(cache_key_1) > 0

    kwargs_2 = {
        "model": "text-embedding-3-small",
        "input": "Hello, world!",
        "dimensions": 1539,
    }
    cache_key_2 = cache.get_cache_key(**kwargs_2)
    assert cache_key_1 != cache_key_2

    kwargs_3 = {
        "model": "text-embedding-3-small",
        "input": "Hello, world!",
        "dimensions": 1539,
    }
    cache_key_3 = cache.get_cache_key(**kwargs_3)
    assert cache_key_2 == cache_key_3


def test_get_cache_key_text_completion():
    cache = Cache()
    kwargs = {
        "model": "gpt-3.5-turbo",
        "prompt": "Hello, world! here is a second line",
        "best_of": 3,
        "logit_bias": {"123": 1},
        "seed": 42,
    }
    cache_key_1 = cache.get_cache_key(**kwargs)
    assert isinstance(cache_key_1, str)
    assert len(cache_key_1) > 0

    kwargs_2 = {
        "model": "gpt-3.5-turbo",
        "prompt": "Hello, world! here is a second line",
        "best_of": 30,
    }
    cache_key_2 = cache.get_cache_key(**kwargs_2)
    assert cache_key_1 != cache_key_2

    kwargs_3 = {
        "model": "gpt-3.5-turbo",
        "prompt": "Hello, world! here is a second line",
        "best_of": 30,
    }
    cache_key_3 = cache.get_cache_key(**kwargs_3)
    assert cache_key_2 == cache_key_3


def test_get_hashed_cache_key():
    cache = Cache()
    cache_key = "model:gpt-3.5-turbo,messages:Hello world"
    hashed_key = cache._get_hashed_cache_key(cache_key)
    assert len(hashed_key) == 64  # SHA-256 produces a 64-character hex string


def test_add_redis_namespace_to_cache_key():
    cache = Cache(namespace="test_namespace")
    hashed_key = "abcdef1234567890"

    # Test with class-level namespace
    result = cache._add_redis_namespace_to_cache_key(hashed_key)
    assert result == "test_namespace:abcdef1234567890"

    # Test with metadata namespace
    kwargs = {"metadata": {"redis_namespace": "custom_namespace"}}
    result = cache._add_redis_namespace_to_cache_key(hashed_key, **kwargs)
    assert result == "custom_namespace:abcdef1234567890"


def test_get_model_param_value():
    cache = Cache()

    # Test with regular model
    kwargs = {"model": "gpt-3.5-turbo"}
    assert cache._get_model_param_value(kwargs) == "gpt-3.5-turbo"

    # Test with model_group
    kwargs = {"model": "gpt-3.5-turbo", "metadata": {"model_group": "gpt-group"}}
    assert cache._get_model_param_value(kwargs) == "gpt-group"

    # Test with caching_group
    kwargs = {
        "model": "gpt-3.5-turbo",
        "metadata": {
            "model_group": "openai-gpt-3.5-turbo",
            "caching_groups": [("openai-gpt-3.5-turbo", "azure-gpt-3.5-turbo")],
        },
    }
    assert (
        cache._get_model_param_value(kwargs)
        == "('openai-gpt-3.5-turbo', 'azure-gpt-3.5-turbo')"
    )

    kwargs = {
        "model": "gpt-3.5-turbo",
        "metadata": {
            "model_group": "azure-gpt-3.5-turbo",
            "caching_groups": [("openai-gpt-3.5-turbo", "azure-gpt-3.5-turbo")],
        },
    }
    assert (
        cache._get_model_param_value(kwargs)
        == "('openai-gpt-3.5-turbo', 'azure-gpt-3.5-turbo')"
    )

    kwargs = {
        "model": "gpt-3.5-turbo",
        "metadata": {
            "model_group": "not-in-caching-group-gpt-3.5-turbo",
            "caching_groups": [("openai-gpt-3.5-turbo", "azure-gpt-3.5-turbo")],
        },
    }
    assert cache._get_model_param_value(kwargs) == "not-in-caching-group-gpt-3.5-turbo"


def test_preset_cache_key():
    """
    Test that the preset cache key is used if it is set in kwargs["litellm_params"]
    """
    cache = Cache()
    kwargs = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "Hello, world!"}],
        "temperature": 0.7,
        "litellm_params": {"preset_cache_key": "preset-cache-key"},
    }

    assert cache.get_cache_key(**kwargs) == "preset-cache-key"


def test_generate_streaming_content():
    cache = Cache()
    content = "Hello, this is a test message."
    generator = cache.generate_streaming_content(content)

    full_response = ""
    chunk_count = 0

    for chunk in generator:
        chunk_count += 1
        assert "choices" in chunk
        assert len(chunk["choices"]) == 1
        assert "delta" in chunk["choices"][0]
        assert "role" in chunk["choices"][0]["delta"]
        assert chunk["choices"][0]["delta"]["role"] == "assistant"
        assert "content" in chunk["choices"][0]["delta"]

        chunk_content = chunk["choices"][0]["delta"]["content"]
        full_response += chunk_content

        # Check that each chunk is no longer than 5 characters
        assert len(chunk_content) <= 5
    print("full_response from generate_streaming_content", full_response)
    # Check that the full content is reconstructed correctly
    assert full_response == content
    # Check that there were multiple chunks
    assert chunk_count > 1

    print(f"Number of chunks: {chunk_count}")