mirror of https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00

commit eaca45cb05 (parent 168ade935e): _update_kwargs_with_default_litellm_params

1 changed file with 12 additions and 15 deletions
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -470,7 +470,7 @@ class Router:
         self.default_litellm_params = default_litellm_params
         self.default_litellm_params.setdefault("timeout", timeout)
         self.default_litellm_params.setdefault("max_retries", 0)
-        self.default_litellm_params.setdefault("litellm_metadata", {}).update(
+        self.default_litellm_params.setdefault("metadata", {}).update(
             {"caching_groups": caching_groups}
         )
 
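For context, the setdefault/update pattern in this hunk seeds router-level defaults without clobbering anything the caller already passed. A minimal standalone sketch; the timeout and caching_groups values here are made up, not from the source:

    # Standalone sketch of the setdefault/update seeding above.
    # All concrete values are hypothetical.
    default_litellm_params: dict = {"timeout": 600}  # pretend the caller set this
    caching_groups = [("gpt-4", "gpt-4-fallback")]

    default_litellm_params.setdefault("timeout", 30)   # no-op: the caller's 600 wins
    default_litellm_params.setdefault("max_retries", 0)
    # setdefault returns the stored dict, so update() merges into any
    # metadata the caller already supplied instead of replacing it.
    default_litellm_params.setdefault("metadata", {}).update(
        {"caching_groups": caching_groups}
    )

    print(default_litellm_params)
    # {'timeout': 600, 'max_retries': 0,
    #  'metadata': {'caching_groups': [('gpt-4', 'gpt-4-fallback')]}}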
@@ -1086,17 +1086,22 @@ class Router:
         kwargs.setdefault("litellm_trace_id", str(uuid.uuid4()))
         kwargs.setdefault("metadata", {}).update({"model_group": model})
 
-    def _update_kwargs_with_default_litellm_params(self, kwargs: dict) -> None:
+    def _update_kwargs_with_default_litellm_params(
+        self, kwargs: dict, metadata_variable_name: str
+    ) -> None:
         """
         Adds default litellm params to kwargs, if set.
         """
+        self.default_litellm_params[metadata_variable_name] = (
+            self.default_litellm_params.pop("metadata")
+        )
         for k, v in self.default_litellm_params.items():
             if (
                 k not in kwargs and v is not None
             ):  # prioritize model-specific params > default router params
                 kwargs[k] = v
-            elif k == "metadata":
-                kwargs[k].update(v)
+            elif k == metadata_variable_name:
+                kwargs[metadata_variable_name].update(v)
 
     def _handle_clientside_credential(
         self, deployment: dict, kwargs: dict
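A standalone sketch of the new merge semantics: the stored "metadata" default is re-keyed to whatever variable name the calling route uses, then merged without clobbering request-specific values. This mirrors the added lines above; the wrapper function and sample values are not from the source.

    # Plain-dict re-implementation of the logic in the hunk above.
    def update_kwargs_with_defaults(
        default_litellm_params: dict, kwargs: dict, metadata_variable_name: str
    ) -> None:
        # Re-key the stored "metadata" default to the route's variable name
        # (e.g. "metadata" or "litellm_metadata").
        default_litellm_params[metadata_variable_name] = default_litellm_params.pop(
            "metadata"
        )
        for k, v in default_litellm_params.items():
            if k not in kwargs and v is not None:
                # prioritize model-specific params > default router params
                kwargs[k] = v
            elif k == metadata_variable_name:
                # merge defaults *into* the caller's metadata dict
                kwargs[metadata_variable_name].update(v)

    defaults = {"timeout": 30, "metadata": {"caching_groups": None}}
    kwargs = {"timeout": 10, "litellm_metadata": {"model_group": "gpt-4"}}
    update_kwargs_with_defaults(defaults, kwargs, "litellm_metadata")
    print(kwargs)
    # {'timeout': 10, 'litellm_metadata': {'model_group': 'gpt-4', 'caching_groups': None}}

One design note: pop("metadata") has no default, so it assumes the "metadata" key seeded in __init__ is still present; once a call has renamed it away, a later call would raise KeyError.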
@@ -1165,7 +1170,9 @@ class Router:
             kwargs=kwargs, data=deployment["litellm_params"]
         )
 
-        self._update_kwargs_with_default_litellm_params(kwargs=kwargs)
+        self._update_kwargs_with_default_litellm_params(
+            kwargs=kwargs, metadata_variable_name=metadata_variable_name
+        )
 
     def _get_async_openai_model_client(self, deployment: dict, kwargs: dict):
         """
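The caller now has to decide which metadata key a route uses. A hypothetical selector, to illustrate why the variable name is a parameter: some routes (e.g. the Responses API) forward "metadata" verbatim to the provider, so router bookkeeping must live under a different key. litellm's real helper may differ; this is an assumption, not its API.

    from typing import Optional

    # Hypothetical selector; not litellm's actual implementation.
    def get_metadata_variable_name(function_name: Optional[str]) -> str:
        if function_name and "responses" in function_name:
            return "litellm_metadata"
        return "metadata"

    assert get_metadata_variable_name("aresponses") == "litellm_metadata"
    assert get_metadata_variable_name("acompletion") == "metadata"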
@@ -2416,18 +2423,12 @@ class Router:
 
         data = deployment["litellm_params"].copy()
         model_name = data["model"]
-
-        model_client = self._get_async_openai_model_client(
-            deployment=deployment,
-            kwargs=kwargs,
-        )
         self.total_calls[model_name] += 1
 
         response = original_function(
             **{
                 **data,
                 "caching": self.cache_responses,
-                "client": model_client,
                 **kwargs,
             }
         )
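The call above is assembled purely with dict unpacking, and in Python later entries overwrite earlier ones, so per-request kwargs override the deployment's litellm_params. A quick illustration with made-up values:

    data = {"model": "openai/gpt-4", "temperature": 0.0}   # deployment params
    kwargs = {"temperature": 0.7}                          # per-request override

    merged = {**data, "caching": False, **kwargs}
    print(merged)  # {'model': 'openai/gpt-4', 'temperature': 0.7, 'caching': False}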
@@ -2498,9 +2499,6 @@ class Router:
         data = deployment["litellm_params"].copy()
         model_name = data["model"]
 
-        model_client = self._get_client(
-            deployment=deployment, kwargs=kwargs, client_type="sync"
-        )
         self.total_calls[model_name] += 1
 
         # Perform pre-call checks for routing strategy
@@ -2510,7 +2508,6 @@ class Router:
             **{
                 **data,
                 "caching": self.cache_responses,
-                "client": model_client,
                 **kwargs,
             }
         )
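Taken together, a rough usage sketch of the path this commit touches; the model list, key, and metadata values are invented, and this assumes the post-commit Router API:

    from litellm import Router

    # All concrete values below are placeholders.
    router = Router(
        model_list=[
            {
                "model_name": "gpt-4",
                "litellm_params": {"model": "openai/gpt-4", "api_key": "sk-..."},
            }
        ],
        default_litellm_params={"metadata": {"team": "platform"}},
    )

    # The router-level metadata default now reaches every route, re-keyed per
    # call to "metadata" or "litellm_metadata" as the endpoint requires.
    response = router.completion(
        model="gpt-4",
        messages=[{"role": "user", "content": "hello"}],
    )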