mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
Merge branch 'main' into litellm_http_proxy_support
This commit is contained in:
commit
058813da76
199 changed files with 18866 additions and 1341 deletions
|
@ -94,11 +94,15 @@ class Router:
|
|||
timeout: Optional[float] = None,
|
||||
default_litellm_params={}, # default params for Router.chat.completion.create
|
||||
set_verbose: bool = False,
|
||||
debug_level: Literal["DEBUG", "INFO"] = "INFO",
|
||||
fallbacks: List = [],
|
||||
allowed_fails: Optional[int] = None,
|
||||
context_window_fallbacks: List = [],
|
||||
model_group_alias: Optional[dict] = {},
|
||||
retry_after: int = 0, # min time to wait before retrying a failed request
|
||||
allowed_fails: Optional[
|
||||
int
|
||||
] = None, # Number of times a deployment can failbefore being added to cooldown
|
||||
cooldown_time: float = 1, # (seconds) time to cooldown a deployment after failure
|
||||
routing_strategy: Literal[
|
||||
"simple-shuffle",
|
||||
"least-busy",
|
||||
|
@ -107,7 +111,42 @@ class Router:
|
|||
] = "simple-shuffle",
|
||||
routing_strategy_args: dict = {}, # just for latency-based routing
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the Router class with the given parameters for caching, reliability, and routing strategy.
|
||||
|
||||
Args:
|
||||
model_list (Optional[list]): List of models to be used. Defaults to None.
|
||||
redis_url (Optional[str]): URL of the Redis server. Defaults to None.
|
||||
redis_host (Optional[str]): Hostname of the Redis server. Defaults to None.
|
||||
redis_port (Optional[int]): Port of the Redis server. Defaults to None.
|
||||
redis_password (Optional[str]): Password of the Redis server. Defaults to None.
|
||||
cache_responses (Optional[bool]): Flag to enable caching of responses. Defaults to False.
|
||||
cache_kwargs (dict): Additional kwargs to pass to RedisCache. Defaults to {}.
|
||||
caching_groups (Optional[List[tuple]]): List of model groups for caching across model groups. Defaults to None.
|
||||
client_ttl (int): Time-to-live for cached clients in seconds. Defaults to 3600.
|
||||
num_retries (int): Number of retries for failed requests. Defaults to 0.
|
||||
timeout (Optional[float]): Timeout for requests. Defaults to None.
|
||||
default_litellm_params (dict): Default parameters for Router.chat.completion.create. Defaults to {}.
|
||||
set_verbose (bool): Flag to set verbose mode. Defaults to False.
|
||||
debug_level (Literal["DEBUG", "INFO"]): Debug level for logging. Defaults to "INFO".
|
||||
fallbacks (List): List of fallback options. Defaults to [].
|
||||
context_window_fallbacks (List): List of context window fallback options. Defaults to [].
|
||||
model_group_alias (Optional[dict]): Alias for model groups. Defaults to {}.
|
||||
retry_after (int): Minimum time to wait before retrying a failed request. Defaults to 0.
|
||||
allowed_fails (Optional[int]): Number of allowed fails before adding to cooldown. Defaults to None.
|
||||
cooldown_time (float): Time to cooldown a deployment after failure in seconds. Defaults to 1.
|
||||
routing_strategy (Literal["simple-shuffle", "least-busy", "usage-based-routing", "latency-based-routing"]): Routing strategy. Defaults to "simple-shuffle".
|
||||
routing_strategy_args (dict): Additional args for latency-based routing. Defaults to {}.
|
||||
|
||||
Returns:
|
||||
Router: An instance of the litellm.Router class.
|
||||
"""
|
||||
self.set_verbose = set_verbose
|
||||
if self.set_verbose:
|
||||
if debug_level == "INFO":
|
||||
verbose_router_logger.setLevel(logging.INFO)
|
||||
elif debug_level == "DEBUG":
|
||||
verbose_router_logger.setLevel(logging.DEBUG)
|
||||
self.deployment_names: List = (
|
||||
[]
|
||||
) # names of models under litellm_params. ex. azure/chatgpt-v-2
|
||||
|
@ -157,6 +196,7 @@ class Router:
|
|||
self.deployment_latency_map[m["litellm_params"]["model"]] = 0
|
||||
|
||||
self.allowed_fails = allowed_fails or litellm.allowed_fails
|
||||
self.cooldown_time = cooldown_time or 1
|
||||
self.failed_calls = (
|
||||
InMemoryCache()
|
||||
) # cache to track failed call per deployment, if num failed calls within 1 minute > allowed fails, then add it to cooldown
|
||||
|
@ -249,16 +289,13 @@ class Router:
|
|||
timeout = kwargs.get("request_timeout", self.timeout)
|
||||
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
|
||||
kwargs.setdefault("metadata", {}).update({"model_group": model})
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
|
||||
# Submit the function to the executor with a timeout
|
||||
future = executor.submit(self.function_with_fallbacks, **kwargs)
|
||||
response = future.result(timeout=timeout) # type: ignore
|
||||
|
||||
response = self.function_with_fallbacks(**kwargs)
|
||||
return response
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def _completion(self, model: str, messages: List[Dict[str, str]], **kwargs):
|
||||
model_name = None
|
||||
try:
|
||||
# pick the one that is available (lowest TPM/RPM)
|
||||
deployment = self.get_available_deployment(
|
||||
|
@ -271,6 +308,7 @@ class Router:
|
|||
)
|
||||
data = deployment["litellm_params"].copy()
|
||||
kwargs["model_info"] = deployment.get("model_info", {})
|
||||
model_name = data["model"]
|
||||
for k, v in self.default_litellm_params.items():
|
||||
if (
|
||||
k not in kwargs
|
||||
|
@ -292,7 +330,7 @@ class Router:
|
|||
else:
|
||||
model_client = potential_model_client
|
||||
|
||||
return litellm.completion(
|
||||
response = litellm.completion(
|
||||
**{
|
||||
**data,
|
||||
"messages": messages,
|
||||
|
@ -301,7 +339,14 @@ class Router:
|
|||
**kwargs,
|
||||
}
|
||||
)
|
||||
verbose_router_logger.info(
|
||||
f"litellm.completion(model={model_name})\033[32m 200 OK\033[0m"
|
||||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
verbose_router_logger.info(
|
||||
f"litellm.completion(model={model_name})\033[31m Exception {str(e)}\033[0m"
|
||||
)
|
||||
raise e
|
||||
|
||||
async def acompletion(self, model: str, messages: List[Dict[str, str]], **kwargs):
|
||||
|
@ -830,6 +875,9 @@ class Router:
|
|||
"""
|
||||
try:
|
||||
kwargs["model"] = mg
|
||||
kwargs.setdefault("metadata", {}).update(
|
||||
{"model_group": mg}
|
||||
) # update model_group used, if fallbacks are done
|
||||
response = await self.async_function_with_retries(
|
||||
*args, **kwargs
|
||||
)
|
||||
|
@ -858,8 +906,10 @@ class Router:
|
|||
f"Falling back to model_group = {mg}"
|
||||
)
|
||||
kwargs["model"] = mg
|
||||
kwargs["metadata"]["model_group"] = mg
|
||||
response = await self.async_function_with_retries(
|
||||
kwargs.setdefault("metadata", {}).update(
|
||||
{"model_group": mg}
|
||||
) # update model_group used, if fallbacks are done
|
||||
response = await self.async_function_with_fallbacks(
|
||||
*args, **kwargs
|
||||
)
|
||||
return response
|
||||
|
@ -1024,6 +1074,9 @@ class Router:
|
|||
## LOGGING
|
||||
kwargs = self.log_retry(kwargs=kwargs, e=original_exception)
|
||||
kwargs["model"] = mg
|
||||
kwargs.setdefault("metadata", {}).update(
|
||||
{"model_group": mg}
|
||||
) # update model_group used, if fallbacks are done
|
||||
response = self.function_with_fallbacks(*args, **kwargs)
|
||||
return response
|
||||
except Exception as e:
|
||||
|
@ -1047,6 +1100,9 @@ class Router:
|
|||
## LOGGING
|
||||
kwargs = self.log_retry(kwargs=kwargs, e=original_exception)
|
||||
kwargs["model"] = mg
|
||||
kwargs.setdefault("metadata", {}).update(
|
||||
{"model_group": mg}
|
||||
) # update model_group used, if fallbacks are done
|
||||
response = self.function_with_fallbacks(*args, **kwargs)
|
||||
return response
|
||||
except Exception as e:
|
||||
|
@ -1232,6 +1288,7 @@ class Router:
|
|||
verbose_router_logger.debug(
|
||||
f"Attempting to add {deployment} to cooldown list. updated_fails: {updated_fails}; self.allowed_fails: {self.allowed_fails}"
|
||||
)
|
||||
cooldown_time = self.cooldown_time or 1
|
||||
if updated_fails > self.allowed_fails:
|
||||
# get the current cooldown list for that minute
|
||||
cooldown_key = f"{current_minute}:cooldown_models" # group cooldown models by minute to reduce number of redis calls
|
||||
|
@ -1245,13 +1302,19 @@ class Router:
|
|||
else:
|
||||
cached_value = cached_value + [deployment]
|
||||
# save updated value
|
||||
self.cache.set_cache(value=cached_value, key=cooldown_key, ttl=1)
|
||||
self.cache.set_cache(
|
||||
value=cached_value, key=cooldown_key, ttl=cooldown_time
|
||||
)
|
||||
except:
|
||||
cached_value = [deployment]
|
||||
# save updated value
|
||||
self.cache.set_cache(value=cached_value, key=cooldown_key, ttl=1)
|
||||
self.cache.set_cache(
|
||||
value=cached_value, key=cooldown_key, ttl=cooldown_time
|
||||
)
|
||||
else:
|
||||
self.failed_calls.set_cache(key=deployment, value=updated_fails, ttl=1)
|
||||
self.failed_calls.set_cache(
|
||||
key=deployment, value=updated_fails, ttl=cooldown_time
|
||||
)
|
||||
|
||||
def _get_cooldown_deployments(self):
|
||||
"""
|
||||
|
@ -1344,6 +1407,7 @@ class Router:
|
|||
max_retries = litellm.get_secret(max_retries_env_name)
|
||||
litellm_params["max_retries"] = max_retries
|
||||
|
||||
|
||||
# proxy support
|
||||
import os
|
||||
import httpx
|
||||
|
@ -1369,6 +1433,12 @@ class Router:
|
|||
),
|
||||
}
|
||||
|
||||
organization = litellm_params.get("organization", None)
|
||||
if isinstance(organization, str) and organization.startswith("os.environ/"):
|
||||
organization_env_name = organization.replace("os.environ/", "")
|
||||
organization = litellm.get_secret(organization_env_name)
|
||||
litellm_params["organization"] = organization
|
||||
|
||||
if "azure" in model_name:
|
||||
if api_base is None:
|
||||
raise ValueError(
|
||||
|
@ -1576,6 +1646,7 @@ class Router:
|
|||
base_url=api_base,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
organization=organization,
|
||||
http_client=httpx.AsyncClient(
|
||||
transport=AsyncCustomHTTPTransport(),
|
||||
limits=httpx.Limits(
|
||||
|
@ -1597,6 +1668,7 @@ class Router:
|
|||
base_url=api_base,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
organization=organization,
|
||||
http_client=httpx.Client(
|
||||
transport=CustomHTTPTransport(),
|
||||
limits=httpx.Limits(
|
||||
|
@ -1619,6 +1691,7 @@ class Router:
|
|||
base_url=api_base,
|
||||
timeout=stream_timeout,
|
||||
max_retries=max_retries,
|
||||
organization=organization,
|
||||
http_client=httpx.AsyncClient(
|
||||
transport=AsyncCustomHTTPTransport(),
|
||||
limits=httpx.Limits(
|
||||
|
@ -1641,6 +1714,7 @@ class Router:
|
|||
base_url=api_base,
|
||||
timeout=stream_timeout,
|
||||
max_retries=max_retries,
|
||||
organization=organization,
|
||||
http_client=httpx.Client(
|
||||
transport=CustomHTTPTransport(),
|
||||
limits=httpx.Limits(
|
||||
|
@ -1865,6 +1939,9 @@ class Router:
|
|||
selected_index = random.choices(range(len(rpms)), weights=weights)[0]
|
||||
verbose_router_logger.debug(f"\n selected index, {selected_index}")
|
||||
deployment = healthy_deployments[selected_index]
|
||||
verbose_router_logger.info(
|
||||
f"get_available_deployment for model: {model}, Selected deployment: {deployment or deployment[0]} for model: {model}"
|
||||
)
|
||||
return deployment or deployment[0]
|
||||
############## Check if we can do a RPM/TPM based weighted pick #################
|
||||
tpm = healthy_deployments[0].get("litellm_params").get("tpm", None)
|
||||
|
@ -1879,6 +1956,9 @@ class Router:
|
|||
selected_index = random.choices(range(len(tpms)), weights=weights)[0]
|
||||
verbose_router_logger.debug(f"\n selected index, {selected_index}")
|
||||
deployment = healthy_deployments[selected_index]
|
||||
verbose_router_logger.info(
|
||||
f"get_available_deployment for model: {model}, Selected deployment: {deployment or deployment[0]} for model: {model}"
|
||||
)
|
||||
return deployment or deployment[0]
|
||||
|
||||
############## No RPM/TPM passed, we do a random pick #################
|
||||
|
@ -1903,8 +1983,13 @@ class Router:
|
|||
)
|
||||
|
||||
if deployment is None:
|
||||
verbose_router_logger.info(
|
||||
f"get_available_deployment for model: {model}, No deployment available"
|
||||
)
|
||||
raise ValueError("No models available.")
|
||||
|
||||
verbose_router_logger.info(
|
||||
f"get_available_deployment for model: {model}, Selected deployment: {deployment} for model: {model}"
|
||||
)
|
||||
return deployment
|
||||
|
||||
def flush_cache(self):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue