Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 19:24:27 +00:00)
(refactor) get_cache_key to be under 100 LOC function (#6327)
* refactor - use helpers for name space and hashing
* use openai to get the relevant supported params
* use helpers for getting cache key
* fix test caching
* use get/set helpers for preset cache keys
* make get_cache_key under 100 LOC
* fix _get_model_param_value
* fix _get_caching_group
* fix linting error
* add unit testing for get cache key
* test_generate_streaming_content
parent 10aaa419a0
commit 7c85aa753f
5 changed files with 477 additions and 124 deletions
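For orientation, the refactored get_cache_key (per the hunks below) boils down to: reuse a preset key if one was stashed, collect the relevant request params, concatenate them into a string, sha256-hash it, prefix any redis namespace, and stash the result. A condensed, standalone sketch of that flow — plain Python, no litellm imports, helper logic inlined and names shortened for illustration only:

import hashlib
from typing import Optional

# Illustrative stand-in for the supported-params set the real helpers compute.
RELEVANT_PARAMS = {"model", "messages", "temperature", "max_tokens"}  # demo subset only

def sketch_get_cache_key(namespace: Optional[str] = None, **kwargs) -> str:
    # 1. reuse a previously computed key if the caller stashed one
    preset = kwargs.get("litellm_params", {}).get("preset_cache_key")
    if preset is not None:
        return preset
    # 2. accumulate "param: value" pairs for supported, non-None params
    cache_key = ""
    for param in kwargs:
        if param in RELEVANT_PARAMS and kwargs[param] is not None:
            cache_key += f"{param}: {kwargs[param]}"
    # 3. sha256-hash the string, 4. prefix the redis namespace if any
    hashed = hashlib.sha256(cache_key.encode()).hexdigest()
    if namespace:
        hashed = f"{namespace}:{hashed}"
    # 5. stash the result so repeated lookups for this request skip the work
    if "litellm_params" in kwargs:
        kwargs["litellm_params"]["preset_cache_key"] = hashed
    return hashed

# e.g. sketch_get_cache_key(model="gpt-4", messages=[{"role": "user", "content": "hi"}])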
@@ -17,13 +17,26 @@ import logging
 import time
 import traceback
 from enum import Enum
-from typing import Any, List, Literal, Optional, Tuple, Union
+from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Union
+
+from openai.types.audio.transcription_create_params import TranscriptionCreateParams
+from openai.types.chat.completion_create_params import (
+    CompletionCreateParamsNonStreaming,
+    CompletionCreateParamsStreaming,
+)
+from openai.types.completion_create_params import (
+    CompletionCreateParamsNonStreaming as TextCompletionCreateParamsNonStreaming,
+)
+from openai.types.completion_create_params import (
+    CompletionCreateParamsStreaming as TextCompletionCreateParamsStreaming,
+)
+from openai.types.embedding_create_params import EmbeddingCreateParams
 from pydantic import BaseModel

 import litellm
 from litellm._logging import verbose_logger
 from litellm.types.caching import *
 from litellm.types.rerank import RerankRequest
 from litellm.types.utils import all_litellm_params

 from .base_cache import BaseCache
@@ -220,7 +233,7 @@ class Cache:
         if self.namespace is not None and isinstance(self.cache, RedisCache):
             self.cache.namespace = self.namespace

-    def get_cache_key(self, *args, **kwargs) -> str:  # noqa: PLR0915
+    def get_cache_key(self, *args, **kwargs) -> str:
         """
         Get the cache key for the given arguments.

@@ -232,107 +245,20 @@ class Cache:
             str: The cache key generated from the arguments, or None if no cache key could be generated.
         """
         cache_key = ""
-        print_verbose(f"\nGetting Cache key. Kwargs: {kwargs}")
+        verbose_logger.debug("\nGetting Cache key. Kwargs: %s", kwargs)

-        # for streaming, we use preset_cache_key. It's created in wrapper(), we do this because optional params like max_tokens, get transformed for bedrock -> max_new_tokens
-        if kwargs.get("litellm_params", {}).get("preset_cache_key", None) is not None:
-            _preset_cache_key = kwargs.get("litellm_params", {}).get(
-                "preset_cache_key", None
-            )
-            print_verbose(f"\nReturning preset cache key: {_preset_cache_key}")
-            return _preset_cache_key
+        preset_cache_key = self._get_preset_cache_key_from_kwargs(**kwargs)
+        if preset_cache_key is not None:
+            verbose_logger.debug("\nReturning preset cache key: %s", preset_cache_key)
+            return preset_cache_key

-        # sort kwargs by keys, since model: [gpt-4, temperature: 0.2, max_tokens: 200] == [temperature: 0.2, max_tokens: 200, model: gpt-4]
-        completion_kwargs = [
-            "model",
-            "messages",
-            "prompt",
-            "temperature",
-            "top_p",
-            "n",
-            "stop",
-            "max_tokens",
-            "presence_penalty",
-            "frequency_penalty",
-            "logit_bias",
-            "user",
-            "response_format",
-            "seed",
-            "tools",
-            "tool_choice",
-            "stream",
-        ]
-        embedding_only_kwargs = [
-            "input",
-            "encoding_format",
-        ]  # embedding kwargs = model, input, user, encoding_format. Model, user are checked in completion_kwargs
-        transcription_only_kwargs = [
-            "file",
-            "language",
-        ]
-        rerank_only_kwargs = [
-            "top_n",
-            "rank_fields",
-            "return_documents",
-            "max_chunks_per_doc",
-            "documents",
-            "query",
-        ]
-        # combined_kwargs - NEEDS to be ordered across get_cache_key(). Do not use a set()
-        combined_kwargs = (
-            completion_kwargs
-            + embedding_only_kwargs
-            + transcription_only_kwargs
-            + rerank_only_kwargs
-        )
+        combined_kwargs = self._get_relevant_args_to_use_for_cache_key()
         litellm_param_kwargs = all_litellm_params
         for param in kwargs:
             if param in combined_kwargs:
-                # check if param == model and model_group is passed in, then override model with model_group
-                if param == "model":
-                    model_group = None
-                    caching_group = None
-                    metadata = kwargs.get("metadata", None)
-                    litellm_params = kwargs.get("litellm_params", {})
-                    if metadata is not None:
-                        model_group = metadata.get("model_group")
-                        model_group = metadata.get("model_group", None)
-                        caching_groups = metadata.get("caching_groups", None)
-                        if caching_groups:
-                            for group in caching_groups:
-                                if model_group in group:
-                                    caching_group = group
-                                    break
-                    if litellm_params is not None:
-                        metadata = litellm_params.get("metadata", None)
-                        if metadata is not None:
-                            model_group = metadata.get("model_group", None)
-                            caching_groups = metadata.get("caching_groups", None)
-                            if caching_groups:
-                                for group in caching_groups:
-                                    if model_group in group:
-                                        caching_group = group
-                                        break
-                    param_value = (
-                        caching_group or model_group or kwargs[param]
-                    )  # use caching_group, if set then model_group if it exists, else use kwargs["model"]
-                elif param == "file":
-                    file = kwargs.get("file")
-                    metadata = kwargs.get("metadata", {})
-                    litellm_params = kwargs.get("litellm_params", {})
-
-                    # get checksum of file content
-                    param_value = (
-                        metadata.get("file_checksum")
-                        or getattr(file, "name", None)
-                        or metadata.get("file_name")
-                        or litellm_params.get("file_name")
-                    )
-                else:
-                    if kwargs[param] is None:
-                        continue  # ignore None params
-                    param_value = kwargs[param]
-                cache_key += f"{str(param)}: {str(param_value)}"
+                param_value: Optional[str] = self._get_param_value(param, kwargs)
+                if param_value is not None:
+                    cache_key += f"{str(param)}: {str(param_value)}"
             elif (
                 param not in litellm_param_kwargs
             ):  # check if user passed in optional param - e.g. top_k
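The model branch removed above (and the _get_model_param_value / _get_caching_group helpers introduced in the next hunk) decide which name represents the request in the key: a caching group wins over the router's model_group, which wins over the raw model string. A small illustrative snippet of that precedence, using made-up deployment names and plain functions rather than litellm's classes:

from typing import List, Optional, Tuple

def resolve_model_for_key(
    model: str,
    model_group: Optional[str] = None,
    caching_groups: Optional[List[Tuple[str, ...]]] = None,
) -> str:
    # If the router's model_group belongs to a caching group, the whole group
    # becomes the "model" component of the key, so deployments in the same
    # group (e.g. "gpt-4" and "azure-gpt-4") share cache hits.
    caching_group = None
    for group in caching_groups or []:
        if model_group in group:
            caching_group = str(group)
            break
    return caching_group or model_group or model

# hypothetical example:
# resolve_model_for_key("azure/gpt-4-eu", model_group="azure-gpt-4",
#                       caching_groups=[("gpt-4", "azure-gpt-4")])
# -> "('gpt-4', 'azure-gpt-4')"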
@@ -344,19 +270,200 @@ class Cache:
                 param_value = kwargs[param]
                 cache_key += f"{str(param)}: {str(param_value)}"

-        print_verbose(f"\nCreated cache key: {cache_key}")
-        # Use hashlib to create a sha256 hash of the cache key
+        verbose_logger.debug("\nCreated cache key: %s", cache_key)
+        hashed_cache_key = self._get_hashed_cache_key(cache_key)
+        hashed_cache_key = self._add_redis_namespace_to_cache_key(
+            hashed_cache_key, **kwargs
+        )
+        self._set_preset_cache_key_in_kwargs(
+            preset_cache_key=hashed_cache_key, **kwargs
+        )
+        return hashed_cache_key
+
+    def _get_param_value(
+        self,
+        param: str,
+        kwargs: dict,
+    ) -> Optional[str]:
+        """
+        Get the value for the given param from kwargs
+        """
+        if param == "model":
+            return self._get_model_param_value(kwargs)
+        elif param == "file":
+            return self._get_file_param_value(kwargs)
+        return kwargs[param]
+
+    def _get_model_param_value(self, kwargs: dict) -> str:
+        """
+        Handles getting the value for the 'model' param from kwargs
+
+        1. If caching groups are set, then return the caching group as the model https://docs.litellm.ai/docs/routing#caching-across-model-groups
+        2. Else if a model_group is set, then return the model_group as the model. This is used for all requests sent through the litellm.Router()
+        3. Else use the `model` passed in kwargs
+        """
+        metadata: Dict = kwargs.get("metadata", {}) or {}
+        litellm_params: Dict = kwargs.get("litellm_params", {}) or {}
+        metadata_in_litellm_params: Dict = litellm_params.get("metadata", {}) or {}
+        model_group: Optional[str] = metadata.get(
+            "model_group"
+        ) or metadata_in_litellm_params.get("model_group")
+        caching_group = self._get_caching_group(metadata, model_group)
+        return caching_group or model_group or kwargs["model"]
+
+    def _get_caching_group(
+        self, metadata: dict, model_group: Optional[str]
+    ) -> Optional[str]:
+        caching_groups: Optional[List] = metadata.get("caching_groups", [])
+        if caching_groups:
+            for group in caching_groups:
+                if model_group in group:
+                    return str(group)
+        return None
+
+    def _get_file_param_value(self, kwargs: dict) -> str:
+        """
+        Handles getting the value for the 'file' param from kwargs. Used for `transcription` requests
+        """
+        file = kwargs.get("file")
+        metadata = kwargs.get("metadata", {})
+        litellm_params = kwargs.get("litellm_params", {})
+        return (
+            metadata.get("file_checksum")
+            or getattr(file, "name", None)
+            or metadata.get("file_name")
+            or litellm_params.get("file_name")
+        )
+
+    def _get_preset_cache_key_from_kwargs(self, **kwargs) -> Optional[str]:
+        """
+        Get the preset cache key from kwargs["litellm_params"]
+
+        We use _get_preset_cache_keys for two reasons
+
+        1. optional params like max_tokens, get transformed for bedrock -> max_new_tokens
+        2. avoid doing duplicate / repeated work
+        """
+        if kwargs:
+            if "litellm_params" in kwargs:
+                return kwargs["litellm_params"].get("preset_cache_key", None)
+        return None
+
+    def _set_preset_cache_key_in_kwargs(self, preset_cache_key: str, **kwargs) -> None:
+        """
+        Set the calculated cache key in kwargs
+
+        This is used to avoid doing duplicate / repeated work
+
+        Placed in kwargs["litellm_params"]
+        """
+        if kwargs:
+            if "litellm_params" in kwargs:
+                kwargs["litellm_params"]["preset_cache_key"] = preset_cache_key
+
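The two preset-key helpers above form a get/set pair around kwargs["litellm_params"]: the first call computes and stashes the hashed key, and later calls for the same request (for example the streaming path, where optional params may have been transformed) read it back instead of rebuilding it. A rough standalone illustration of that contract, using plain dicts rather than litellm's actual call path:

def read_preset_cache_key(kwargs: dict):
    # mirrors _get_preset_cache_key_from_kwargs: only looks inside litellm_params
    if "litellm_params" in kwargs:
        return kwargs["litellm_params"].get("preset_cache_key")
    return None

def write_preset_cache_key(kwargs: dict, key: str) -> None:
    # mirrors _set_preset_cache_key_in_kwargs: stores the computed key for reuse
    if "litellm_params" in kwargs:
        kwargs["litellm_params"]["preset_cache_key"] = key

request_kwargs = {"model": "gpt-4", "litellm_params": {}}
assert read_preset_cache_key(request_kwargs) is None       # first call: nothing stashed yet
write_preset_cache_key(request_kwargs, "abc123")            # get_cache_key stores its result
assert read_preset_cache_key(request_kwargs) == "abc123"    # later calls reuse it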
+    def _get_relevant_args_to_use_for_cache_key(self) -> Set[str]:
+        """
+        Gets the supported kwargs for each call type and combines them
+        """
+        chat_completion_kwargs = self._get_litellm_supported_chat_completion_kwargs()
+        text_completion_kwargs = self._get_litellm_supported_text_completion_kwargs()
+        embedding_kwargs = self._get_litellm_supported_embedding_kwargs()
+        transcription_kwargs = self._get_litellm_supported_transcription_kwargs()
+        rerank_kwargs = self._get_litellm_supported_rerank_kwargs()
+        exclude_kwargs = self._get_kwargs_to_exclude_from_cache_key()
+
+        combined_kwargs = chat_completion_kwargs.union(
+            text_completion_kwargs,
+            embedding_kwargs,
+            transcription_kwargs,
+            rerank_kwargs,
+        )
+        combined_kwargs = combined_kwargs.difference(exclude_kwargs)
+        return combined_kwargs
+
+    def _get_litellm_supported_chat_completion_kwargs(self) -> Set[str]:
+        """
+        Get the litellm supported chat completion kwargs
+
+        This follows the OpenAI API Spec
+        """
+        all_chat_completion_kwargs = set(
+            CompletionCreateParamsNonStreaming.__annotations__.keys()
+        ).union(set(CompletionCreateParamsStreaming.__annotations__.keys()))
+        return all_chat_completion_kwargs
+
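_get_relevant_args_to_use_for_cache_key replaces the hand-maintained lists with introspection: the OpenAI request TypedDicts expose their parameter names through __annotations__, the per-call-type sets are unioned, and excluded keys are pruned. A small sketch of the same pattern with a local TypedDict, so it runs without the openai package; the class name and fields here are illustrative only:

from typing import List, Set, TypedDict

# Stand-in for e.g. CompletionCreateParamsNonStreaming; the real helpers read
# the OpenAI TypedDicts' __annotations__ the same way.
class DemoCompletionParams(TypedDict, total=False):
    model: str
    messages: List[dict]
    temperature: float
    metadata: dict

supported: Set[str] = set(DemoCompletionParams.__annotations__.keys())
relevant = supported.difference({"metadata"})  # drop kwargs that should not affect the key
print(sorted(relevant))  # ['messages', 'model', 'temperature']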
+    def _get_litellm_supported_text_completion_kwargs(self) -> Set[str]:
+        """
+        Get the litellm supported text completion kwargs
+
+        This follows the OpenAI API Spec
+        """
+        all_text_completion_kwargs = set(
+            TextCompletionCreateParamsNonStreaming.__annotations__.keys()
+        ).union(set(TextCompletionCreateParamsStreaming.__annotations__.keys()))
+        return all_text_completion_kwargs
+
+    def _get_litellm_supported_rerank_kwargs(self) -> Set[str]:
+        """
+        Get the litellm supported rerank kwargs
+        """
+        return set(RerankRequest.model_fields.keys())
+
+    def _get_litellm_supported_embedding_kwargs(self) -> Set[str]:
+        """
+        Get the litellm supported embedding kwargs
+
+        This follows the OpenAI API Spec
+        """
+        return set(EmbeddingCreateParams.__annotations__.keys())
+
+    def _get_litellm_supported_transcription_kwargs(self) -> Set[str]:
+        """
+        Get the litellm supported transcription kwargs
+
+        This follows the OpenAI API Spec
+        """
+        return set(TranscriptionCreateParams.__annotations__.keys())
+
+    def _get_kwargs_to_exclude_from_cache_key(self) -> Set[str]:
+        """
+        Get the kwargs to exclude from the cache key
+        """
+        return set(["metadata"])
+
+    def _get_hashed_cache_key(self, cache_key: str) -> str:
+        """
+        Get the hashed cache key for the given cache key.
+
+        Use hashlib to create a sha256 hash of the cache key
+
+        Args:
+            cache_key (str): The cache key to hash.
+
+        Returns:
+            str: The hashed cache key.
+        """
         hash_object = hashlib.sha256(cache_key.encode())
         # Hexadecimal representation of the hash
         hash_hex = hash_object.hexdigest()
-        print_verbose(f"Hashed cache key (SHA-256): {hash_hex}")
-        if kwargs.get("metadata", {}).get("redis_namespace", None) is not None:
-            _namespace = kwargs.get("metadata", {}).get("redis_namespace", None)
-            hash_hex = f"{_namespace}:{hash_hex}"
-            print_verbose(f"Hashed Key with Namespace: {hash_hex}")
-        elif self.namespace is not None:
-            hash_hex = f"{self.namespace}:{hash_hex}"
-            print_verbose(f"Hashed Key with Namespace: {hash_hex}")
+        verbose_logger.debug("Hashed cache key (SHA-256): %s", hash_hex)
         return hash_hex
+
+    def _add_redis_namespace_to_cache_key(self, hash_hex: str, **kwargs) -> str:
+        """
+        If a redis namespace is provided, add it to the cache key
+
+        Args:
+            hash_hex (str): The hashed cache key.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            str: The final hashed cache key with the redis namespace.
+        """
+        namespace = kwargs.get("metadata", {}).get("redis_namespace") or self.namespace
+        if namespace:
+            hash_hex = f"{namespace}:{hash_hex}"
+        verbose_logger.debug("Final hashed key: %s", hash_hex)
+        return hash_hex

     def generate_streaming_content(self, content):
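The last two helpers are easy to exercise on their own: the key is the hex sha256 digest of the accumulated "param: value" string (the pairs are concatenated with no separator), and a redis namespace (from request metadata or the Cache instance) is prepended as a "namespace:" prefix. A quick standalone check of that behaviour with plain hashlib and illustrative values:

import hashlib

cache_key = "model: gpt-4messages: [{'role': 'user', 'content': 'hi'}]"
hash_hex = hashlib.sha256(cache_key.encode()).hexdigest()
print(hash_hex)  # 64-character hex digest, stable across processes

namespace = "team-a"  # e.g. metadata["redis_namespace"] or Cache(namespace=...)
if namespace:
    hash_hex = f"{namespace}:{hash_hex}"
print(hash_hex)  # "team-a:<sha256>"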