Mirror of https://github.com/BerriAI/litellm.git
Synced 2025-04-24 18:24:20 +00:00
fix(router.py): fix fallbacks
parent cc3d7da9a0
commit 59ba1560e5

3 changed files with 46 additions and 9 deletions
router.py

@@ -144,12 +144,13 @@ class Router:
             kwargs["model"] = model
             kwargs["messages"] = messages
             kwargs["original_function"] = self._completion
-            kwargs["num_retries"] = self.num_retries
+            timeout = kwargs.get("request_timeout", self.timeout)
+            kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
             kwargs.setdefault("metadata", {}).update({"model_group": model})
             with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
                 # Submit the function to the executor with a timeout
                 future = executor.submit(self.function_with_fallbacks, **kwargs)
-                response = future.result(timeout=self.timeout) # type: ignore
+                response = future.result(timeout=timeout) # type: ignore

                 return response
         except Exception as e:
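The effect of this hunk: the synchronous completion() path now honors per-request overrides instead of always using the router-level defaults. A minimal usage sketch, assuming a Router built from a placeholder model_list (the model names and API key below are illustrative, not from this commit):

    from litellm import Router

    # Placeholder deployment list; real provider credentials go here.
    model_list = [{
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {"model": "gpt-3.5-turbo", "api_key": "sk-..."},
    }]

    router = Router(model_list=model_list, timeout=30, num_retries=3)

    # After this commit, per-request values win over the router defaults:
    # this call times out after 10s and retries at most once.
    response = router.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hello"}],
        request_timeout=10,
        num_retries=1,
    )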
@@ -164,6 +165,7 @@ class Router:
         try:
             # pick the one that is available (lowest TPM/RPM)
             deployment = self.get_available_deployment(model=model, messages=messages)
+            kwargs.setdefault("metadata", {}).update({"deployment": deployment["litellm_params"]["model"]})
             data = deployment["litellm_params"]
             for k, v in self.default_litellm_params.items():
                 if k not in data: # prioritize model-specific params > default router params
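The added metadata line records which concrete deployment served the request, so a later failure can be attributed to that deployment (see the cooldown hunk further down) rather than only to the model group. A small self-contained sketch of the merge pattern itself (the function and names are illustrative, not litellm code):

    def attach_routing_metadata(kwargs: dict, chosen_deployment: str, model_group: str) -> dict:
        # setdefault returns the existing dict when "metadata" is already set,
        # so this merges into caller-supplied metadata instead of clobbering it.
        kwargs.setdefault("metadata", {}).update({
            "model_group": model_group,
            "deployment": chosen_deployment,
        })
        return kwargs

    kwargs = {"metadata": {"request_id": "abc"}}  # caller metadata survives
    kwargs = attach_routing_metadata(kwargs, "azure/gpt-35-turbo", "gpt-3.5-turbo")
    assert kwargs["metadata"] == {
        "request_id": "abc",
        "model_group": "gpt-3.5-turbo",
        "deployment": "azure/gpt-35-turbo",
    }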
@@ -182,9 +184,10 @@ class Router:
             kwargs["model"] = model
             kwargs["messages"] = messages
             kwargs["original_function"] = self._acompletion
-            kwargs["num_retries"] = self.num_retries
+            kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
+            timeout = kwargs.get("request_timeout", self.timeout)
             kwargs.setdefault("metadata", {}).update({"model_group": model})
-            response = await asyncio.wait_for(self.async_function_with_fallbacks(**kwargs), timeout=self.timeout)
+            response = await asyncio.wait_for(self.async_function_with_fallbacks(**kwargs), timeout=timeout)

             return response
         except Exception as e:
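On the async path the same per-request timeout is enforced with asyncio.wait_for, which cancels the pending call and raises asyncio.TimeoutError once the deadline passes. A self-contained illustration of that behavior (not litellm code):

    import asyncio

    async def slow_call() -> str:
        await asyncio.sleep(5)  # stand-in for a slow LLM call
        return "done"

    async def main() -> None:
        timeout = 0.1  # analogous to request_timeout falling back to self.timeout
        try:
            result = await asyncio.wait_for(slow_call(), timeout=timeout)
            print(result)
        except asyncio.TimeoutError:
            # wait_for cancels slow_call() before raising, so no work leaks
            print("timed out")

    asyncio.run(main())  # prints "timed out"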
@@ -196,7 +199,9 @@ class Router:
                            messages: List[Dict[str, str]],
                            **kwargs):
         try:
+            self.print_verbose(f"Inside _acompletion()- model: {model}; kwargs: {kwargs}")
             deployment = self.get_available_deployment(model=model, messages=messages)
+            kwargs.setdefault("metadata", {}).update({"deployment": deployment["litellm_params"]["model"]})
             data = deployment["litellm_params"]
             for k, v in self.default_litellm_params.items():
                 if k not in data: # prioritize model-specific params > default router params
@@ -244,7 +249,7 @@ class Router:
                   **kwargs) -> Union[List[float], None]:
         # pick the one that is available (lowest TPM/RPM)
         deployment = self.get_available_deployment(model=model, input=input)
-
+        kwargs.setdefault("metadata", {}).update({"deployment": deployment["litellm_params"]["model"]})
         data = deployment["litellm_params"]
         for k, v in self.default_litellm_params.items():
             if k not in data: # prioritize model-specific params > default router params
@@ -259,7 +264,7 @@ class Router:
                   **kwargs) -> Union[List[float], None]:
         # pick the one that is available (lowest TPM/RPM)
         deployment = self.get_available_deployment(model=model, input=input)
-
+        kwargs.setdefault("metadata", {}).update({"deployment": deployment["litellm_params"]["model"]})
         data = deployment["litellm_params"]
         for k, v in self.default_litellm_params.items():
             if k not in data: # prioritize model-specific params > default router params
@@ -315,10 +320,11 @@ class Router:
                         """
                         try:
+                            kwargs["model"] = mg
                             kwargs["metadata"]["model_group"] = mg
                             response = await self.async_function_with_retries(*args, **kwargs)
                             return response
                         except Exception as e:
-                            raise e
+                            pass
         except Exception as e:
             self.print_verbose(f"An exception occurred - {str(e)}")
             traceback.print_exc()
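The two edits in this hunk appear to be the heart of the fallback fix: the fallback group is now written into kwargs["model"] before retrying (previously only the metadata changed, so the same failing group would be called again), and one fallback's failure no longer aborts the loop. The control flow, reduced to a hedged sketch with illustrative names:

    def call_with_fallbacks(call, model: str, fallbacks: dict):
        """call(model) is a stand-in for the router's retry wrapper."""
        try:
            return call(model)
        except Exception as original_exception:
            for mg in fallbacks.get(model, []):
                try:
                    # the fix: actually switch the model before retrying
                    return call(mg)
                except Exception:
                    pass  # the fix: try the next fallback instead of re-raising
            raise original_exception

    calls = []
    def flaky(model):
        calls.append(model)
        if model != "gpt-4":
            raise RuntimeError(f"{model} unavailable")
        return "ok"

    assert call_with_fallbacks(flaky, "gpt-3.5-turbo",
                               {"gpt-3.5-turbo": ["bad-model", "gpt-4"]}) == "ok"
    assert calls == ["gpt-3.5-turbo", "bad-model", "gpt-4"]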
@@ -328,12 +334,14 @@ class Router:
         self.print_verbose(f"Inside async function with retries: args - {args}; kwargs - {kwargs}")
         backoff_factor = 1
         original_function = kwargs.pop("original_function")
+        self.print_verbose(f"async function w/ retries: original_function - {original_function}")
         num_retries = kwargs.pop("num_retries")
         try:
             # if the function call is successful, no exception will be raised and we'll break out of the loop
             response = await original_function(*args, **kwargs)
             return response
         except Exception as e:
             original_exception = e
             for current_attempt in range(num_retries):
+                self.print_verbose(f"retrying request. Current attempt - {current_attempt}; num retries: {num_retries}")
                 try:
@@ -359,7 +367,7 @@ class Router:
                         pass
                     else:
                         raise e
-            raise e
+            raise original_exception

     def function_with_fallbacks(self, *args, **kwargs):
         """
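Together with the previous hunk, the retry wrapper now preserves the first failure: the request is retried up to num_retries times, and if every attempt fails the original exception, not the last retry's, is what surfaces. A condensed synchronous sketch (the real method is async and also distinguishes retryable from non-retryable errors):

    import time

    def function_with_retries(fn, num_retries: int, backoff_factor: float = 1.0):
        try:
            return fn()
        except Exception as original_exception:
            for current_attempt in range(num_retries):
                try:
                    return fn()
                except Exception:
                    # simple exponential backoff between attempts
                    time.sleep(backoff_factor * (2 ** current_attempt))
            # all retries exhausted: surface the first error, not the last
            raise original_exception

    attempts = []
    def always_fails():
        attempts.append(len(attempts))
        raise ValueError(f"failure #{len(attempts)}")

    try:
        function_with_retries(always_fails, num_retries=2, backoff_factor=0.01)
    except ValueError as e:
        assert str(e) == "failure #1"  # the original error surfaces
    assert len(attempts) == 3  # 1 initial call + 2 retries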
@@ -499,10 +507,13 @@ class Router:
         try:
             model_name = kwargs.get('model', None) # i.e. gpt35turbo
             custom_llm_provider = kwargs.get("litellm_params", {}).get('custom_llm_provider', None) # i.e. azure
+            metadata = kwargs.get("litellm_params", {}).get('metadata', None)
+            if metadata:
+                deployment = metadata.get("deployment", None)
+                self._set_cooldown_deployments(deployment)
             if custom_llm_provider:
                 model_name = f"{custom_llm_provider}/{model_name}"

             self._set_cooldown_deployments(model_name)
         except Exception as e:
             raise e
-
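This hunk extends the failure callback: besides cooling down the provider/model name, it now also cools down the exact deployment recorded in the request metadata by the earlier hunks. A toy sketch of a time-windowed cooldown set (class, window length, and deployment id are all illustrative assumptions, not litellm's internals):

    import time

    class CooldownTracker:
        def __init__(self, cooldown_seconds: float = 60.0):
            self.cooldown_seconds = cooldown_seconds
            self._cooling: dict[str, float] = {}  # deployment -> time it entered cooldown

        def set_cooldown(self, deployment: str) -> None:
            self._cooling[deployment] = time.time()

        def cooling_down(self) -> set:
            now = time.time()
            # drop entries whose cooldown window has elapsed
            self._cooling = {d: t for d, t in self._cooling.items()
                             if now - t < self.cooldown_seconds}
            return set(self._cooling)

    tracker = CooldownTracker(cooldown_seconds=60.0)
    tracker.set_cooldown("azure/gpt-35-turbo-eastus")  # hypothetical deployment id
    assert "azure/gpt-35-turbo-eastus" in tracker.cooling_down()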
@@ -731,6 +742,10 @@ class Router:
         ### get all deployments
         ### filter out the deployments currently cooling down
         healthy_deployments = [m for m in self.model_list if m["model_name"] == model]
+        if len(healthy_deployments) == 0:
+            # check if the user sent in a deployment name instead
+            healthy_deployments = [m for m in self.model_list if m["litellm_params"]["model"] == model]
+        self.print_verbose(f"initial list of deployments: {healthy_deployments}")
         deployments_to_remove = []
         cooldown_deployments = self._get_cooldown_deployments()
         self.print_verbose(f"cooldown deployments: {cooldown_deployments}")
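The new lines make get_available_deployment forgiving about its input: if nothing matches the model-group name, the router checks whether the caller passed a specific deployment name (the litellm_params.model value) instead. A sketch of that prefilter composed with the cooldown filter (simplified; the real method goes on to load-balance the survivors by TPM/RPM):

    def healthy_candidates(model_list: list, model: str, cooldowns: set) -> list:
        # match on the model-group name first
        healthy = [m for m in model_list if m["model_name"] == model]
        if len(healthy) == 0:
            # fall back to matching a specific deployment name
            healthy = [m for m in model_list if m["litellm_params"]["model"] == model]
        # filter out deployments currently cooling down
        return [m for m in healthy if m["litellm_params"]["model"] not in cooldowns]

    model_list = [
        {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "azure/gpt-35-eu"}},
        {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "azure/gpt-35-us"}},
    ]
    assert len(healthy_candidates(model_list, "gpt-3.5-turbo", {"azure/gpt-35-eu"})) == 1
    assert len(healthy_candidates(model_list, "azure/gpt-35-us", set())) == 1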