diff --git a/litellm/router.py b/litellm/router.py
index ed28d1d482..aeaabac60d 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -470,7 +470,7 @@ class Router:
         self.default_litellm_params = default_litellm_params
         self.default_litellm_params.setdefault("timeout", timeout)
         self.default_litellm_params.setdefault("max_retries", 0)
-        self.default_litellm_params.setdefault("litellm_metadata", {}).update(
+        self.default_litellm_params.setdefault("metadata", {}).update(
             {"caching_groups": caching_groups}
         )
 
@@ -1086,17 +1086,22 @@ class Router:
         kwargs.setdefault("litellm_trace_id", str(uuid.uuid4()))
         kwargs.setdefault("metadata", {}).update({"model_group": model})
 
-    def _update_kwargs_with_default_litellm_params(self, kwargs: dict) -> None:
+    def _update_kwargs_with_default_litellm_params(
+        self, kwargs: dict, metadata_variable_name: str = "metadata"
+    ) -> None:
         """
         Adds default litellm params to kwargs, if set.
         """
         for k, v in self.default_litellm_params.items():
+            if k == "metadata":
+                # remap per call; never mutate the shared router defaults
+                k = metadata_variable_name
             if (
                 k not in kwargs and v is not None
             ):  # prioritize model-specific params > default router params
                 kwargs[k] = v
-            elif k == "metadata":
-                kwargs[k].update(v)
+            elif k == metadata_variable_name:
+                kwargs[metadata_variable_name].update(v)
 
     def _handle_clientside_credential(
         self, deployment: dict, kwargs: dict
@@ -1165,7 +1170,9 @@ class Router:
             kwargs=kwargs, data=deployment["litellm_params"]
         )
 
-        self._update_kwargs_with_default_litellm_params(kwargs=kwargs)
+        self._update_kwargs_with_default_litellm_params(
+            kwargs=kwargs, metadata_variable_name=metadata_variable_name
+        )
 
     def _get_async_openai_model_client(self, deployment: dict, kwargs: dict):
         """
@@ -2416,18 +2423,12 @@ class Router:
 
             data = deployment["litellm_params"].copy()
             model_name = data["model"]
-
-            model_client = self._get_async_openai_model_client(
-                deployment=deployment,
-                kwargs=kwargs,
-            )
             self.total_calls[model_name] += 1
 
             response = original_function(
                 **{
                     **data,
                     "caching": self.cache_responses,
-                    "client": model_client,
                     **kwargs,
                 }
             )
@@ -2498,9 +2499,6 @@ class Router:
 
             data = deployment["litellm_params"].copy()
             model_name = data["model"]
-            model_client = self._get_client(
-                deployment=deployment, kwargs=kwargs, client_type="sync"
-            )
             self.total_calls[model_name] += 1
 
             # Perform pre-call checks for routing strategy
@@ -2510,7 +2508,6 @@ class Router:
                 **{
                     **data,
                     "caching": self.cache_responses,
-                    "client": model_client,
                     **kwargs,
                 }
             )