diff --git a/litellm/caching.py b/litellm/caching.py index 8fa8678a1d..d14ef8a657 100644 --- a/litellm/caching.py +++ b/litellm/caching.py @@ -206,10 +206,13 @@ class Cache: cache_key ="" for param in kwargs: # ignore litellm params here - if param in set(["litellm_call_id", "litellm_logging_obj"]): - continue - param_value = kwargs[param] - cache_key+= f"{str(param)}: {str(param_value)}" + if param in set(["model", "messages", "temperature", "top_p", "n", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "response_format", "seed", "tools", "tool_choice"]): + # check if param == model and model_group is passed in, then override model with model_group + if param == "model" and kwargs.get("metadata", None) is not None and kwargs["metadata"].get("model_group", None) is not None: + param_value = kwargs["metadata"].get("model_group", None) # for litellm.Router use model_group for caching over `model` + else: + param_value = kwargs[param] + cache_key+= f"{str(param)}: {str(param_value)}" return cache_key def generate_streaming_content(self, content):