diff --git a/litellm/caching.py b/litellm/caching.py index cbaa3bb20..1d993927b 100644 --- a/litellm/caching.py +++ b/litellm/caching.py @@ -230,8 +230,15 @@ class Cache: if param in kwargs: # check if param == model and model_group is passed in, then override model with model_group if param == "model": - # for litellm.Router use model_group for caching over `model` - model_group = kwargs.get("metadata", {}).get("model_group", None) or kwargs.get("litellm_params", {}).get("metadata", {}).get("model_group", None) + model_group = None + metadata = kwargs.get("metadata", None) + litellm_params = kwargs.get("litellm_params", {}) + if metadata is not None: + model_group = metadata.get("model_group") + if litellm_params is not None: + metadata = litellm_params.get("metadata", None) + if metadata is not None: + model_group = metadata.get("model_group", None) param_value = model_group or kwargs[param] # use model_group if it exists, else use kwargs["model"] else: if kwargs[param] is None: