From ca852e1dcdc3fde2ff396e3ac4a7d5e4bbe5991c Mon Sep 17 00:00:00 2001
From: ishaan-jaff
Date: Thu, 23 Nov 2023 20:56:09 -0800
Subject: [PATCH] (fix) caching use model, messages, temp, max_tokens as
 cache_key

---
 litellm/caching.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/litellm/caching.py b/litellm/caching.py
index 8fa8678a1d..d14ef8a657 100644
--- a/litellm/caching.py
+++ b/litellm/caching.py
@@ -206,10 +206,13 @@ class Cache:
         cache_key =""
         for param in kwargs:
             # ignore litellm params here
-            if param in set(["litellm_call_id", "litellm_logging_obj"]):
-                continue
-            param_value = kwargs[param]
-            cache_key+= f"{str(param)}: {str(param_value)}"
+            if param in set(["model", "messages", "temperature", "top_p", "n", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "response_format", "seed", "tools", "tool_choice"]):
+                # check if param == model and model_group is passed in, then override model with model_group
+                if param == "model" and kwargs.get("metadata", None) is not None and kwargs["metadata"].get("model_group", None) is not None:
+                    param_value = kwargs["metadata"].get("model_group", None) # for litellm.Router use model_group for caching over `model`
+                else:
+                    param_value = kwargs[param]
+                cache_key+= f"{str(param)}: {str(param_value)}"
         return cache_key
 
     def generate_streaming_content(self, content):
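
For context, below is a minimal standalone sketch of the cache-key behaviour this patch introduces: only the whitelisted completion params contribute to the key, and when litellm.Router passes metadata["model_group"], that group name is used in place of the concrete model so all deployments in the group share one cache entry. The function name get_cache_key_sketch, the deployment names, and the exact key format are illustrative only, not litellm's public API.

# Sketch of the patched key construction (illustrative; not litellm's actual method).
CACHED_PARAMS = {
    "model", "messages", "temperature", "top_p", "n", "stop", "max_tokens",
    "presence_penalty", "frequency_penalty", "logit_bias", "user",
    "response_format", "seed", "tools", "tool_choice",
}

def get_cache_key_sketch(**kwargs) -> str:
    cache_key = ""
    for param in kwargs:
        # only whitelisted completion params contribute; internal kwargs
        # such as litellm_call_id or logging objects are simply skipped
        if param in CACHED_PARAMS:
            if param == "model" and (kwargs.get("metadata") or {}).get("model_group") is not None:
                # for litellm.Router, key on the model_group rather than the
                # concrete deployment name
                param_value = kwargs["metadata"]["model_group"]
            else:
                param_value = kwargs[param]
            cache_key += f"{param}: {param_value}"
    return cache_key

# Two calls hitting different deployments of the same group produce the same
# key, and litellm_call_id no longer perturbs it.
key_a = get_cache_key_sketch(
    model="azure/gpt-35-turbo-eu",
    messages=[{"role": "user", "content": "hi"}],
    temperature=0.1,
    metadata={"model_group": "gpt-3.5-turbo"},
    litellm_call_id="abc-123",
)
key_b = get_cache_key_sketch(
    model="azure/gpt-35-turbo-us",
    messages=[{"role": "user", "content": "hi"}],
    temperature=0.1,
    metadata={"model_group": "gpt-3.5-turbo"},
    litellm_call_id="def-456",
)
assert key_a == key_b

The effect of the change is that the key is insensitive to per-call bookkeeping (call IDs, logging objects) and to which concrete deployment a Router picks, so identical requests can hit the cache regardless of routing.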