mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 11:14:04 +00:00)

fix(router.py): fix fallbacks

parent: e0fe4a13a2
commit: 2680f84cfc

3 changed files with 46 additions and 9 deletions
@@ -839,6 +839,7 @@ If the error is a context window exceeded error, fall back to a larger model group

 [**See Code**](https://github.com/BerriAI/litellm/blob/main/litellm/router.py)

+**Set via config**
 ```yaml
 model_list:
   - model_name: zephyr-beta
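
For reference, the same policy the config keys express can be set programmatically. A minimal sketch, assuming the `Router` constructor accepts `fallbacks` and `context_window_fallbacks` keyword arguments mirroring the `litellm_settings` keys (the deployment values are illustrative):

```python
from litellm import Router

# illustrative model_list: one group alias plus the fallback target
model_list = [
    {"model_name": "zephyr-beta",
     "litellm_params": {"model": "huggingface/HuggingFaceH4/zephyr-7b-beta"}},
    {"model_name": "gpt-3.5-turbo",
     "litellm_params": {"model": "gpt-3.5-turbo"}},
]

router = Router(
    model_list=model_list,
    fallbacks=[{"zephyr-beta": ["gpt-3.5-turbo"]}],
    context_window_fallbacks=[{"zephyr-beta": ["gpt-3.5-turbo"]}],
    num_retries=2,
)
```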
@@ -870,6 +871,26 @@ litellm_settings:
   allowed_fails: 3 # cooldown model if it fails > 1 call in a minute.
 ```

+**Set dynamically**
+
+```bash
+curl --location 'http://0.0.0.0:8000/chat/completions' \
+--header 'Content-Type: application/json' \
+--data ' {
+  "model": "zephyr-beta",
+  "messages": [
+      {
+          "role": "user",
+          "content": "what llm are you"
+      }
+  ],
+  "fallbacks": [{"zephyr-beta": ["gpt-3.5-turbo"]}],
+  "context_window_fallbacks": [{"zephyr-beta": ["gpt-3.5-turbo"]}],
+  "num_retries": 2,
+  "request_timeout": 10
+}
+'
+```
+
 ### Config for Embedding Models - xorbitsai/inference

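The curl example above maps one-to-one onto a plain HTTP POST; a Python equivalent of the same request (endpoint and payload exactly as in the docs hunk):

```python
import requests

# mirrors the curl call from the docs: per-request fallbacks, retries, timeout
resp = requests.post(
    "http://0.0.0.0:8000/chat/completions",
    headers={"Content-Type": "application/json"},
    json={
        "model": "zephyr-beta",
        "messages": [{"role": "user", "content": "what llm are you"}],
        "fallbacks": [{"zephyr-beta": ["gpt-3.5-turbo"]}],
        "context_window_fallbacks": [{"zephyr-beta": ["gpt-3.5-turbo"]}],
        "num_retries": 2,
        "request_timeout": 10,
    },
)
print(resp.json())
```
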
@@ -144,12 +144,13 @@ class Router:
             kwargs["model"] = model
             kwargs["messages"] = messages
             kwargs["original_function"] = self._completion
-            kwargs["num_retries"] = self.num_retries
+            timeout = kwargs.get("request_timeout", self.timeout)
+            kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
             kwargs.setdefault("metadata", {}).update({"model_group": model})
             with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
                 # Submit the function to the executor with a timeout
                 future = executor.submit(self.function_with_fallbacks, **kwargs)
-                response = future.result(timeout=self.timeout) # type: ignore
+                response = future.result(timeout=timeout) # type: ignore

             return response
         except Exception as e:
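
This hunk lets a per-call `request_timeout` / `num_retries` override the router-level defaults instead of the defaults always winning. A minimal sketch of the pattern in isolation (names hypothetical, not litellm's API):

```python
import concurrent.futures

def call_with_overrides(fn, default_timeout, default_retries, **kwargs):
    # per-call values win; instance defaults are only a fallback
    timeout = kwargs.pop("request_timeout", default_timeout)
    kwargs["num_retries"] = kwargs.get("num_retries", default_retries)
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
        future = executor.submit(fn, **kwargs)
        # raises concurrent.futures.TimeoutError if fn is still running
        return future.result(timeout=timeout)
```

Note that `future.result(timeout=...)` only stops the wait; the worker thread keeps running in the background after a timeout.
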
@@ -164,6 +165,7 @@ class Router:
         try:
             # pick the one that is available (lowest TPM/RPM)
             deployment = self.get_available_deployment(model=model, messages=messages)
+            kwargs.setdefault("metadata", {}).update({"deployment": deployment["litellm_params"]["model"]})
             data = deployment["litellm_params"]
             for k, v in self.default_litellm_params.items():
                 if k not in data: # prioritize model-specific params > default router params
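
Recording the chosen deployment in `metadata` is what later lets the failure callback cool down the specific deployment rather than the whole model group. The `setdefault(...).update(...)` idiom preserves any caller-supplied metadata; a small sketch of that behavior:

```python
def tag_metadata(kwargs: dict, model_group: str, deployment: str) -> dict:
    # setdefault returns the caller's metadata dict if one was passed,
    # otherwise installs a fresh one; update layers router keys on top
    kwargs.setdefault("metadata", {}).update(
        {"model_group": model_group, "deployment": deployment}
    )
    return kwargs

kwargs = {"metadata": {"trace_id": "abc"}}  # caller-supplied key survives
tag_metadata(kwargs, "zephyr-beta", "huggingface/HuggingFaceH4/zephyr-7b-beta")
assert kwargs["metadata"]["trace_id"] == "abc"
```
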
@@ -182,9 +184,10 @@ class Router:
             kwargs["model"] = model
             kwargs["messages"] = messages
             kwargs["original_function"] = self._acompletion
-            kwargs["num_retries"] = self.num_retries
+            kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
+            timeout = kwargs.get("request_timeout", self.timeout)
             kwargs.setdefault("metadata", {}).update({"model_group": model})
-            response = await asyncio.wait_for(self.async_function_with_fallbacks(**kwargs), timeout=self.timeout)
+            response = await asyncio.wait_for(self.async_function_with_fallbacks(**kwargs), timeout=timeout)

             return response
         except Exception as e:
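
`acompletion` gets the same per-call override, with `asyncio.wait_for` in place of the thread-pool wait. A sketch of the async shape (hypothetical names):

```python
import asyncio

async def acall_with_timeout(coro_fn, default_timeout, **kwargs):
    timeout = kwargs.pop("request_timeout", default_timeout)
    # unlike future.result(), wait_for cancels the task when the timeout fires
    return await asyncio.wait_for(coro_fn(**kwargs), timeout=timeout)
```
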
@@ -196,7 +199,9 @@ class Router:
                           messages: List[Dict[str, str]],
                           **kwargs):
         try:
+            self.print_verbose(f"Inside _acompletion()- model: {model}; kwargs: {kwargs}")
             deployment = self.get_available_deployment(model=model, messages=messages)
+            kwargs.setdefault("metadata", {}).update({"deployment": deployment["litellm_params"]["model"]})
             data = deployment["litellm_params"]
             for k, v in self.default_litellm_params.items():
                 if k not in data: # prioritize model-specific params > default router params
@@ -244,7 +249,7 @@ class Router:
                   **kwargs) -> Union[List[float], None]:
         # pick the one that is available (lowest TPM/RPM)
         deployment = self.get_available_deployment(model=model, input=input)
+        kwargs.setdefault("metadata", {}).update({"deployment": deployment["litellm_params"]["model"]})
         data = deployment["litellm_params"]
         for k, v in self.default_litellm_params.items():
             if k not in data: # prioritize model-specific params > default router params
@@ -259,7 +264,7 @@ class Router:
                   **kwargs) -> Union[List[float], None]:
         # pick the one that is available (lowest TPM/RPM)
         deployment = self.get_available_deployment(model=model, input=input)
+        kwargs.setdefault("metadata", {}).update({"deployment": deployment["litellm_params"]["model"]})
         data = deployment["litellm_params"]
         for k, v in self.default_litellm_params.items():
             if k not in data: # prioritize model-specific params > default router params
@@ -315,10 +320,11 @@ class Router:
                     """
                     try:
                         kwargs["model"] = mg
+                        kwargs["metadata"]["model_group"] = mg
                         response = await self.async_function_with_retries(*args, **kwargs)
                         return response
                     except Exception as e:
-                        pass
+                        raise e
         except Exception as e:
             self.print_verbose(f"An exception occurred - {str(e)}")
             traceback.print_exc()
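
Previously the inner handler swallowed the failure with `pass`; re-raising lets the enclosing `except` (visible above) log the exception and print the traceback instead of failing silently. As a generic sketch of a fallback loop of this kind (hypothetical names, not a line-for-line rendering of litellm's control flow):

```python
async def try_with_fallbacks(call, model_group, fallback_groups):
    # try the requested group, then each fallback in order;
    # remember the first failure so the caller sees the real root cause
    first_exc = None
    for mg in [model_group, *fallback_groups]:
        try:
            return await call(mg)
        except Exception as e:
            first_exc = first_exc or e
    raise first_exc
```
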
@@ -328,12 +334,14 @@ class Router:
         self.print_verbose(f"Inside async function with retries: args - {args}; kwargs - {kwargs}")
         backoff_factor = 1
         original_function = kwargs.pop("original_function")
+        self.print_verbose(f"async function w/ retries: original_function - {original_function}")
         num_retries = kwargs.pop("num_retries")
         try:
             # if the function call is successful, no exception will be raised and we'll break out of the loop
             response = await original_function(*args, **kwargs)
             return response
         except Exception as e:
+            original_exception = e
             for current_attempt in range(num_retries):
                 self.print_verbose(f"retrying request. Current attempt - {current_attempt}; num retries: {num_retries}")
                 try:
@@ -359,7 +367,7 @@ class Router:
                         pass
                     else:
                         raise e
-            raise e
+            raise original_exception

     def function_with_fallbacks(self, *args, **kwargs):
         """
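
Storing `original_exception` fixes a subtle bug: Python 3 unbinds an `except ... as e` variable when its block exits, so the inner retry loop's own `except Exception as e` left the trailing `raise e` pointing at an unbound name (an `UnboundLocalError` at runtime); even where it did work, it would surface the last retry's error rather than the first, more meaningful one. A minimal sketch of the corrected retry shape (hypothetical names, simplified backoff):

```python
import asyncio

async def retry_async(fn, num_retries, backoff_factor=1, **kwargs):
    try:
        return await fn(**kwargs)
    except Exception as e:
        original_exception = e  # `e` is unbound after this block; this name survives
        for attempt in range(num_retries):
            try:
                # simple exponential backoff between attempts
                await asyncio.sleep(backoff_factor * (2 ** attempt))
                return await fn(**kwargs)
            except Exception:
                continue
        # every retry failed: surface the first error, not the last one
        raise original_exception
```
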
@@ -499,10 +507,13 @@ class Router:
         try:
             model_name = kwargs.get('model', None) # i.e. gpt35turbo
             custom_llm_provider = kwargs.get("litellm_params", {}).get('custom_llm_provider', None) # i.e. azure
+            metadata = kwargs.get("litellm_params", {}).get('metadata', None)
+            if metadata:
+                deployment = metadata.get("deployment", None)
+                self._set_cooldown_deployments(deployment)
             if custom_llm_provider:
                 model_name = f"{custom_llm_provider}/{model_name}"
-
-            self._set_cooldown_deployments(model_name)
         except Exception as e:
             raise e
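
Cooling down by the `deployment` recorded in metadata, rather than by the provider-qualified model name, targets the single failing deployment instead of the whole group. A hedged sketch of the idea behind a failure-count cooldown (not litellm's implementation; the doc's `allowed_fails` semantics assumed):

```python
import time
from collections import defaultdict

class CooldownTracker:
    """Cool a deployment down once it fails more than allowed_fails in a minute."""

    def __init__(self, cooldown_seconds: int = 60, allowed_fails: int = 1):
        self.cooldown_seconds = cooldown_seconds
        self.allowed_fails = allowed_fails
        self._failures = defaultdict(list)   # deployment -> failure timestamps
        self._cooldown_start = {}            # deployment -> cooldown start time

    def record_failure(self, deployment: str) -> None:
        now = time.time()
        # keep only failures from the last minute, then add this one
        recent = [t for t in self._failures[deployment] if now - t < 60]
        recent.append(now)
        self._failures[deployment] = recent
        if len(recent) > self.allowed_fails:
            self._cooldown_start[deployment] = now

    def is_cooling_down(self, deployment: str) -> bool:
        start = self._cooldown_start.get(deployment)
        return start is not None and time.time() - start < self.cooldown_seconds
```
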
@@ -731,6 +742,10 @@ class Router:
         ### get all deployments
         ### filter out the deployments currently cooling down
         healthy_deployments = [m for m in self.model_list if m["model_name"] == model]
+        if len(healthy_deployments) == 0:
+            # check if the user sent in a deployment name instead
+            healthy_deployments = [m for m in self.model_list if m["litellm_params"]["model"] == model]
+        self.print_verbose(f"initial list of deployments: {healthy_deployments}")
         deployments_to_remove = []
         cooldown_deployments = self._get_cooldown_deployments()
         self.print_verbose(f"cooldown deployments: {cooldown_deployments}")

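With the added lookup, a caller may pass either the group alias (`model_name`) or the underlying deployment name (`litellm_params.model`). A self-contained illustration of the two-step match (model_list values are illustrative):

```python
# illustrative only: one group alias ("zephyr-beta") backed by two deployments
model_list = [
    {"model_name": "zephyr-beta",
     "litellm_params": {"model": "huggingface/HuggingFaceH4/zephyr-7b-beta"}},
    {"model_name": "zephyr-beta",
     "litellm_params": {"model": "gpt-3.5-turbo"}},
]

def find_deployments(model_list, model):
    # group-alias match first, then fall back to the raw deployment name
    matches = [m for m in model_list if m["model_name"] == model]
    if not matches:
        matches = [m for m in model_list if m["litellm_params"]["model"] == model]
    return matches

assert len(find_deployments(model_list, "zephyr-beta")) == 2
assert len(find_deployments(model_list, "gpt-3.5-turbo")) == 1
```
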
@@ -680,6 +680,7 @@ class Logging:
             print_verbose(
                 f"Logging Details Post-API Call: logger_fn - {self.logger_fn} | callable(logger_fn) - {callable(self.logger_fn)}"
             )
+            print_verbose(f"Logging Details Post-API Call: LiteLLM Params: {self.model_call_details}")
             if self.logger_fn and callable(self.logger_fn):
                 try:
                     self.logger_fn(