_update_kwargs_with_default_litellm_params

This commit is contained in:
Ishaan Jaff 2025-03-12 18:33:56 -07:00
parent 168ade935e
commit eaca45cb05

View file

@ -470,7 +470,7 @@ class Router:
self.default_litellm_params = default_litellm_params self.default_litellm_params = default_litellm_params
self.default_litellm_params.setdefault("timeout", timeout) self.default_litellm_params.setdefault("timeout", timeout)
self.default_litellm_params.setdefault("max_retries", 0) self.default_litellm_params.setdefault("max_retries", 0)
self.default_litellm_params.setdefault("litellm_metadata", {}).update( self.default_litellm_params.setdefault("metadata", {}).update(
{"caching_groups": caching_groups} {"caching_groups": caching_groups}
) )
@ -1086,17 +1086,22 @@ class Router:
kwargs.setdefault("litellm_trace_id", str(uuid.uuid4())) kwargs.setdefault("litellm_trace_id", str(uuid.uuid4()))
kwargs.setdefault("metadata", {}).update({"model_group": model}) kwargs.setdefault("metadata", {}).update({"model_group": model})
def _update_kwargs_with_default_litellm_params(self, kwargs: dict) -> None: def _update_kwargs_with_default_litellm_params(
self, kwargs: dict, metadata_variable_name: str
) -> None:
""" """
Adds default litellm params to kwargs, if set. Adds default litellm params to kwargs, if set.
""" """
self.default_litellm_params[metadata_variable_name] = (
self.default_litellm_params.pop("metadata")
)
for k, v in self.default_litellm_params.items(): for k, v in self.default_litellm_params.items():
if ( if (
k not in kwargs and v is not None k not in kwargs and v is not None
): # prioritize model-specific params > default router params ): # prioritize model-specific params > default router params
kwargs[k] = v kwargs[k] = v
elif k == "metadata": elif k == metadata_variable_name:
kwargs[k].update(v) kwargs[metadata_variable_name].update(v)
def _handle_clientside_credential( def _handle_clientside_credential(
self, deployment: dict, kwargs: dict self, deployment: dict, kwargs: dict
@ -1165,7 +1170,9 @@ class Router:
kwargs=kwargs, data=deployment["litellm_params"] kwargs=kwargs, data=deployment["litellm_params"]
) )
self._update_kwargs_with_default_litellm_params(kwargs=kwargs) self._update_kwargs_with_default_litellm_params(
kwargs=kwargs, metadata_variable_name=metadata_variable_name
)
def _get_async_openai_model_client(self, deployment: dict, kwargs: dict): def _get_async_openai_model_client(self, deployment: dict, kwargs: dict):
""" """
@ -2416,18 +2423,12 @@ class Router:
data = deployment["litellm_params"].copy() data = deployment["litellm_params"].copy()
model_name = data["model"] model_name = data["model"]
model_client = self._get_async_openai_model_client(
deployment=deployment,
kwargs=kwargs,
)
self.total_calls[model_name] += 1 self.total_calls[model_name] += 1
response = original_function( response = original_function(
**{ **{
**data, **data,
"caching": self.cache_responses, "caching": self.cache_responses,
"client": model_client,
**kwargs, **kwargs,
} }
) )
@ -2498,9 +2499,6 @@ class Router:
data = deployment["litellm_params"].copy() data = deployment["litellm_params"].copy()
model_name = data["model"] model_name = data["model"]
model_client = self._get_client(
deployment=deployment, kwargs=kwargs, client_type="sync"
)
self.total_calls[model_name] += 1 self.total_calls[model_name] += 1
# Perform pre-call checks for routing strategy # Perform pre-call checks for routing strategy
@ -2510,7 +2508,6 @@ class Router:
**{ **{
**data, **data,
"caching": self.cache_responses, "caching": self.cache_responses,
"client": model_client,
**kwargs, **kwargs,
} }
) )