_update_kwargs_with_default_litellm_params

This commit is contained in:
Ishaan Jaff 2025-03-12 18:33:56 -07:00
parent 168ade935e
commit eaca45cb05

View file

@ -470,7 +470,7 @@ class Router:
self.default_litellm_params = default_litellm_params
self.default_litellm_params.setdefault("timeout", timeout)
self.default_litellm_params.setdefault("max_retries", 0)
self.default_litellm_params.setdefault("litellm_metadata", {}).update(
self.default_litellm_params.setdefault("metadata", {}).update(
{"caching_groups": caching_groups}
)
@ -1086,17 +1086,22 @@ class Router:
kwargs.setdefault("litellm_trace_id", str(uuid.uuid4()))
kwargs.setdefault("metadata", {}).update({"model_group": model})
def _update_kwargs_with_default_litellm_params(self, kwargs: dict) -> None:
def _update_kwargs_with_default_litellm_params(
self, kwargs: dict, metadata_variable_name: str
) -> None:
"""
Adds default litellm params to kwargs, if set.
"""
self.default_litellm_params[metadata_variable_name] = (
self.default_litellm_params.pop("metadata")
)
for k, v in self.default_litellm_params.items():
if (
k not in kwargs and v is not None
): # prioritize model-specific params > default router params
kwargs[k] = v
elif k == "metadata":
kwargs[k].update(v)
elif k == metadata_variable_name:
kwargs[metadata_variable_name].update(v)
def _handle_clientside_credential(
self, deployment: dict, kwargs: dict
@ -1165,7 +1170,9 @@ class Router:
kwargs=kwargs, data=deployment["litellm_params"]
)
self._update_kwargs_with_default_litellm_params(kwargs=kwargs)
self._update_kwargs_with_default_litellm_params(
kwargs=kwargs, metadata_variable_name=metadata_variable_name
)
def _get_async_openai_model_client(self, deployment: dict, kwargs: dict):
"""
@ -2416,18 +2423,12 @@ class Router:
data = deployment["litellm_params"].copy()
model_name = data["model"]
model_client = self._get_async_openai_model_client(
deployment=deployment,
kwargs=kwargs,
)
self.total_calls[model_name] += 1
response = original_function(
**{
**data,
"caching": self.cache_responses,
"client": model_client,
**kwargs,
}
)
@ -2498,9 +2499,6 @@ class Router:
data = deployment["litellm_params"].copy()
model_name = data["model"]
model_client = self._get_client(
deployment=deployment, kwargs=kwargs, client_type="sync"
)
self.total_calls[model_name] += 1
# Perform pre-call checks for routing strategy
@ -2510,7 +2508,6 @@ class Router:
**{
**data,
"caching": self.cache_responses,
"client": model_client,
**kwargs,
}
)