forked from phoenix/litellm-mirror
LiteLLM Minor Fixes and Improvements (09/13/2024) (#5689)
* refactor: cleanup unused variables + fix pyright errors
* feat(health_check.py): Closes https://github.com/BerriAI/litellm/issues/5686
* fix(o1_reasoning.py): add stricter check for o-1 reasoning model
* refactor(mistral/): make it easier to see mistral transformation logic
* fix(openai.py): fix openai o-1 model param mapping
  Fixes https://github.com/BerriAI/litellm/issues/5685
* feat(main.py): infer finetuned gemini model from base model
  Fixes https://github.com/BerriAI/litellm/issues/5678
* docs(vertex.md): update docs to call finetuned gemini models
* feat(proxy_server.py): allow admin to hide proxy model aliases
  Closes https://github.com/BerriAI/litellm/issues/5692
* docs(load_balancing.md): add docs on hiding alias models from proxy config
* fix(base.py): don't raise notimplemented error
* fix(user_api_key_auth.py): fix model max budget check
* fix(router.py): fix elif
* fix(user_api_key_auth.py): don't set team_id to empty str
* fix(team_endpoints.py): fix response type
* test(test_completion.py): handle predibase error
* test(test_proxy_server.py): fix test
* fix(o1_transformation.py): fix max_completion_token mapping
* test(test_image_generation.py): mark flaky test
This commit is contained in:
parent db3af20d84
commit 60709a0753
35 changed files with 1020 additions and 539 deletions
litellm/router.py

@@ -92,6 +92,7 @@ from litellm.types.router import (
     RetryPolicy,
     RouterErrors,
     RouterGeneralSettings,
+    RouterModelGroupAliasItem,
     RouterRateLimitError,
     RouterRateLimitErrorBasic,
     updateDeployment,
@@ -105,6 +106,7 @@ from litellm.utils import (
     calculate_max_parallel_requests,
     create_proxy_transport_and_mounts,
     get_llm_provider,
+    get_secret,
     get_utc_datetime,
 )
@@ -156,7 +158,9 @@ class Router:
         fallbacks: List = [],
         context_window_fallbacks: List = [],
         content_policy_fallbacks: List = [],
-        model_group_alias: Optional[dict] = {},
+        model_group_alias: Optional[
+            Dict[str, Union[str, RouterModelGroupAliasItem]]
+        ] = {},
         enable_pre_call_checks: bool = False,
         enable_tag_filtering: bool = False,
         retry_after: int = 0, # min time to wait before retrying a failed request
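For orientation (separate from the diff itself): the retyped model_group_alias parameter accepts either a plain string alias or a RouterModelGroupAliasItem-style dict with "model" and "hidden" keys, as later hunks in this file show. A minimal sketch of how a caller might pass both forms; the model names are placeholders, not values from this commit:

    from litellm import Router

    router = Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {"model": "gpt-3.5-turbo"},
            }
        ],
        model_group_alias={
            # plain string alias: requests for "gpt-4" are routed to the "gpt-3.5-turbo" group
            "gpt-4": "gpt-3.5-turbo",
            # dict alias: still routable, but hidden from user-facing model listings
            "internal-gpt": {"model": "gpt-3.5-turbo", "hidden": True},
        },
    )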
@@ -331,7 +335,8 @@ class Router:
             self.set_model_list(model_list)
             self.healthy_deployments: List = self.model_list # type: ignore
             for m in model_list:
-                self.deployment_latency_map[m["litellm_params"]["model"]] = 0
+                if "model" in m["litellm_params"]:
+                    self.deployment_latency_map[m["litellm_params"]["model"]] = 0
         else:
             self.model_list: List = (
                 []
@@ -398,7 +403,7 @@ class Router:
         self.previous_models: List = (
             []
         ) # list to store failed calls (passed in as metadata to next call)
-        self.model_group_alias: dict = (
+        self.model_group_alias: Dict[str, Union[str, RouterModelGroupAliasItem]] = (
             model_group_alias or {}
         ) # dict to store aliases for router, ex. {"gpt-4": "gpt-3.5-turbo"}, all requests with gpt-4 -> get routed to gpt-3.5-turbo group
@@ -1179,6 +1184,7 @@ class Router:
             raise e

     def _image_generation(self, prompt: str, model: str, **kwargs):
+        model_name = ""
         try:
             verbose_router_logger.debug(
                 f"Inside _image_generation()- model: {model}; kwargs: {kwargs}"
@@ -1269,6 +1275,7 @@ class Router:
             raise e

     async def _aimage_generation(self, prompt: str, model: str, **kwargs):
+        model_name = ""
         try:
             verbose_router_logger.debug(
                 f"Inside _image_generation()- model: {model}; kwargs: {kwargs}"
@@ -1401,6 +1408,7 @@ class Router:
             raise e

     async def _atranscription(self, file: FileTypes, model: str, **kwargs):
+        model_name = model
         try:
             verbose_router_logger.debug(
                 f"Inside _atranscription()- model: {model}; kwargs: {kwargs}"
@@ -1781,6 +1789,7 @@ class Router:
         is_async: Optional[bool] = False,
         **kwargs,
     ):
+        messages = [{"role": "user", "content": prompt}]
         try:
             kwargs["model"] = model
             kwargs["prompt"] = prompt
@@ -1789,7 +1798,6 @@ class Router:
             timeout = kwargs.get("request_timeout", self.timeout)
             kwargs.setdefault("metadata", {}).update({"model_group": model})

-            messages = [{"role": "user", "content": prompt}]
             # pick the one that is available (lowest TPM/RPM)
             deployment = self.get_available_deployment(
                 model=model,
@@ -2534,13 +2542,13 @@ class Router:
             try:
                 # Update kwargs with the current model name or any other model-specific adjustments
                 ## SET CUSTOM PROVIDER TO SELECTED DEPLOYMENT ##
-                _, custom_llm_provider, _, _ = get_llm_provider(
+                _, custom_llm_provider, _, _ = get_llm_provider( # type: ignore
                     model=model_name["litellm_params"]["model"]
                 )
                 new_kwargs = copy.deepcopy(kwargs)
                 new_kwargs.pop("custom_llm_provider", None)
                 return await litellm.aretrieve_batch(
-                    custom_llm_provider=custom_llm_provider, **new_kwargs
+                    custom_llm_provider=custom_llm_provider, **new_kwargs # type: ignore
                 )
             except Exception as e:
                 receieved_exceptions.append(e)
@@ -2616,13 +2624,13 @@ class Router:
         for result in results:
             if result is not None:
                 ## check batch id
-                if final_results["first_id"] is None:
-                    final_results["first_id"] = result.first_id
-                    final_results["last_id"] = result.last_id
+                if final_results["first_id"] is None and hasattr(result, "first_id"):
+                    final_results["first_id"] = getattr(result, "first_id")
+                    final_results["last_id"] = getattr(result, "last_id")
                 final_results["data"].extend(result.data) # type: ignore

                 ## check 'has_more'
-                if result.has_more is True:
+                if getattr(result, "has_more", False) is True:
                     final_results["has_more"] = True

         return final_results
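The hunk above merges paginated batch results defensively: the hasattr/getattr guards let a result object without cursor fields pass through instead of raising AttributeError. A small self-contained sketch of that merge pattern, using hypothetical page objects rather than litellm types:

    # two hypothetical pages: one with cursor fields, one without
    class PageWithCursors:
        first_id, last_id, data, has_more = "batch_1", "batch_2", ["a", "b"], False

    class PageWithoutCursors:
        data, has_more = ["c"], True

    final_results = {"first_id": None, "last_id": None, "data": [], "has_more": False}
    for result in (PageWithCursors(), PageWithoutCursors()):
        if final_results["first_id"] is None and hasattr(result, "first_id"):
            final_results["first_id"] = getattr(result, "first_id")
            final_results["last_id"] = getattr(result, "last_id")
        final_results["data"].extend(result.data)
        if getattr(result, "has_more", False) is True:
            final_results["has_more"] = True

    # final_results: cursor ids come from the first page that had them,
    # data is merged from both pages, has_more is True because one page reported it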
@@ -2874,8 +2882,12 @@ class Router:
             verbose_router_logger.debug(f"Traceback{traceback.format_exc()}")
             original_exception = e
             fallback_model_group = None
-            original_model_group = kwargs.get("model")
+            original_model_group: Optional[str] = kwargs.get("model") # type: ignore
             fallback_failure_exception_str = ""
+
+            if original_model_group is None:
+                raise e
+
             try:
                 verbose_router_logger.debug("Trying to fallback b/w models")
                 if isinstance(e, litellm.ContextWindowExceededError):
@@ -2972,7 +2984,7 @@ class Router:
                         f"No fallback model group found for original model_group={model_group}. Fallbacks={fallbacks}"
                     )
                     if hasattr(original_exception, "message"):
-                        original_exception.message += f"No fallback model group found for original model_group={model_group}. Fallbacks={fallbacks}"
+                        original_exception.message += f"No fallback model group found for original model_group={model_group}. Fallbacks={fallbacks}" # type: ignore
                     raise original_exception

                 response = await run_async_fallback(
@@ -2996,12 +3008,12 @@ class Router:

             if hasattr(original_exception, "message"):
                 # add the available fallbacks to the exception
-                original_exception.message += "\nReceived Model Group={}\nAvailable Model Group Fallbacks={}".format(
+                original_exception.message += "\nReceived Model Group={}\nAvailable Model Group Fallbacks={}".format( # type: ignore
                     model_group,
                     fallback_model_group,
                 )
                 if len(fallback_failure_exception_str) > 0:
-                    original_exception.message += (
+                    original_exception.message += ( # type: ignore
                         "\nError doing the fallback: {}".format(
                             fallback_failure_exception_str
                         )
@@ -3117,9 +3129,15 @@ class Router:
                     ## LOGGING
                     kwargs = self.log_retry(kwargs=kwargs, e=e)
                     remaining_retries = num_retries - current_attempt
-                    _healthy_deployments, _ = await self._async_get_healthy_deployments(
-                        model=kwargs.get("model"),
-                    )
+                    _model: Optional[str] = kwargs.get("model") # type: ignore
+                    if _model is not None:
+                        _healthy_deployments, _ = (
+                            await self._async_get_healthy_deployments(
+                                model=_model,
+                            )
+                        )
+                    else:
+                        _healthy_deployments = []
                     _timeout = self._time_to_sleep_before_retry(
                         e=original_exception,
                         remaining_retries=remaining_retries,
@@ -3129,8 +3147,8 @@ class Router:
                     await asyncio.sleep(_timeout)

             if type(original_exception) in litellm.LITELLM_EXCEPTION_TYPES:
-                original_exception.max_retries = num_retries
-                original_exception.num_retries = current_attempt
+                setattr(original_exception, "max_retries", num_retries)
+                setattr(original_exception, "num_retries", current_attempt)

             raise original_exception

@@ -3225,8 +3243,12 @@ class Router:
             return response
         except Exception as e:
             original_exception = e
-            original_model_group = kwargs.get("model")
+            original_model_group: Optional[str] = kwargs.get("model")
             verbose_router_logger.debug(f"An exception occurs {original_exception}")
+
+            if original_model_group is None:
+                raise e
+
             try:
                 verbose_router_logger.debug(
                     f"Trying to fallback b/w models. Initial model group: {model_group}"
@@ -3336,10 +3358,10 @@ class Router:
             return 0

         response_headers: Optional[httpx.Headers] = None
-        if hasattr(e, "response") and hasattr(e.response, "headers"):
-            response_headers = e.response.headers
+        if hasattr(e, "response") and hasattr(e.response, "headers"): # type: ignore
+            response_headers = e.response.headers # type: ignore
         elif hasattr(e, "litellm_response_headers"):
-            response_headers = e.litellm_response_headers
+            response_headers = e.litellm_response_headers # type: ignore

         if response_headers is not None:
             timeout = litellm._calculate_retry_after(
@@ -3398,9 +3420,13 @@ class Router:
         except Exception as e:
             current_attempt = None
             original_exception = e
+            _model: Optional[str] = kwargs.get("model") # type: ignore
+
+            if _model is None:
+                raise e # re-raise error, if model can't be determined for loadbalancing
             ### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR
             _healthy_deployments, _all_deployments = self._get_healthy_deployments(
-                model=kwargs.get("model"),
+                model=_model,
             )

             # raises an exception if this error should not be retries
@@ -3438,8 +3464,12 @@ class Router:
                 except Exception as e:
                     ## LOGGING
                     kwargs = self.log_retry(kwargs=kwargs, e=e)
+                    _model: Optional[str] = kwargs.get("model") # type: ignore
+
+                    if _model is None:
+                        raise e # re-raise error, if model can't be determined for loadbalancing
                     _healthy_deployments, _ = self._get_healthy_deployments(
-                        model=kwargs.get("model"),
+                        model=_model,
                     )
                     remaining_retries = num_retries - current_attempt
                     _timeout = self._time_to_sleep_before_retry(
@@ -4055,7 +4085,7 @@ class Router:
         if isinstance(_litellm_params, dict):
             for k, v in _litellm_params.items():
                 if isinstance(v, str) and v.startswith("os.environ/"):
-                    _litellm_params[k] = litellm.get_secret(v)
+                    _litellm_params[k] = get_secret(v)

         _model_info: dict = model.pop("model_info", {})
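The second hunk above swaps litellm.get_secret for the directly imported get_secret when resolving "os.environ/"-prefixed values in litellm_params. Roughly, the convention is that such values are looked up at load time rather than stored literally; a simplified sketch using a plain environment lookup (get_secret can also resolve from configured secret managers, and the key name here is a placeholder):

    import os

    litellm_params = {"model": "gpt-3.5-turbo", "api_key": "os.environ/OPENAI_API_KEY"}

    for k, v in litellm_params.items():
        if isinstance(v, str) and v.startswith("os.environ/"):
            # replace the sentinel with the actual value from the environment
            litellm_params[k] = os.environ.get(v.replace("os.environ/", "", 1))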
@@ -4392,7 +4422,6 @@ class Router:
         - ModelGroupInfo if able to construct a model group
         - None if error constructing model group info
         """
-
         model_group_info: Optional[ModelGroupInfo] = None

         total_tpm: Optional[int] = None
@@ -4557,12 +4586,23 @@ class Router:

         Returns:
         - ModelGroupInfo if able to construct a model group
-        - None if error constructing model group info
+        - None if error constructing model group info or hidden model group
         """
         ## Check if model group alias
         if model_group in self.model_group_alias:
+            item = self.model_group_alias[model_group]
+            if isinstance(item, str):
+                _router_model_group = item
+            elif isinstance(item, dict):
+                if item["hidden"] is True:
+                    return None
+                else:
+                    _router_model_group = item["model"]
+            else:
+                return None
+
             return self._set_model_group_info(
-                model_group=self.model_group_alias[model_group],
+                model_group=_router_model_group,
                 user_facing_model_group_name=model_group,
             )
@@ -4666,7 +4706,14 @@ class Router:

         Includes model_group_alias models too.
         """
-        return self.model_names + list(self.model_group_alias.keys())
+        model_list = self.get_model_list()
+        if model_list is None:
+            return []
+
+        model_names = []
+        for m in model_list:
+            model_names.append(m["model_name"])
+        return model_names

     def get_model_list(
         self, model_name: Optional[str] = None
@@ -4678,9 +4725,21 @@ class Router:
         returned_models: List[DeploymentTypedDict] = []

         for model_alias, model_value in self.model_group_alias.items():
+
+            if isinstance(model_value, str):
+                _router_model_name: str = model_value
+            elif isinstance(model_value, dict):
+                _model_value = RouterModelGroupAliasItem(**model_value) # type: ignore
+                if _model_value["hidden"] is True:
+                    continue
+                else:
+                    _router_model_name = _model_value["model"]
+            else:
+                continue
+
             returned_models.extend(
                 self._get_all_deployments(
-                    model_name=model_value, model_alias=model_alias
+                    model_name=_router_model_name, model_alias=model_alias
                 )
             )
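Downstream effect of the alias expansion above: aliases marked "hidden" drop out of user-facing listings while plain aliases are included. A short usage sketch, again with placeholder names and reusing the alias shapes from the earlier constructor sketch:

    from litellm import Router

    router = Router(
        model_list=[
            {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "gpt-3.5-turbo"}}
        ],
        model_group_alias={
            "gpt-4": "gpt-3.5-turbo",
            "internal-gpt": {"model": "gpt-3.5-turbo", "hidden": True},
        },
    )

    deployments = router.get_model_list() or []
    print(sorted({m["model_name"] for m in deployments}))
    # per the loop above, the visible alias is expanded and the hidden one is skipped,
    # so "internal-gpt" should not appear in the output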
@@ -5078,10 +5137,11 @@ class Router:
             )

         if model in self.model_group_alias:
-            verbose_router_logger.debug(
-                f"Using a model alias. Got Request for {model}, sending requests to {self.model_group_alias.get(model)}"
-            )
-            model = self.model_group_alias[model]
+            _item = self.model_group_alias[model]
+            if isinstance(_item, str):
+                model = _item
+            else:
+                model = _item["model"]

         if model not in self.model_names:
             # check if provider/ specific wildcard routing
@@ -5124,7 +5184,9 @@ class Router:
             m for m in self.model_list if m["litellm_params"]["model"] == model
         ]

-        litellm.print_verbose(f"initial list of deployments: {healthy_deployments}")
+        verbose_router_logger.debug(
+            f"initial list of deployments: {healthy_deployments}"
+        )

         if len(healthy_deployments) == 0:
             raise ValueError(
@@ -5208,7 +5270,7 @@ class Router:
             )

         # check if user wants to do tag based routing
-        healthy_deployments = await get_deployments_for_tag(
+        healthy_deployments = await get_deployments_for_tag( # type: ignore
             llm_router_instance=self,
             request_kwargs=request_kwargs,
             healthy_deployments=healthy_deployments,
@@ -5241,7 +5303,7 @@ class Router:
                     input=input,
                 )
             )
-        if (
+        elif (
             self.routing_strategy == "cost-based-routing"
             and self.lowestcost_logger is not None
         ):
@@ -5326,6 +5388,8 @@ class Router:
             ############## No RPM/TPM passed, we do a random pick #################
             item = random.choice(healthy_deployments)
             return item or item[0]
+        else:
+            deployment = None
         if deployment is None:
             verbose_router_logger.info(
                 f"get_available_deployment for model: {model}, No deployment available"
@@ -5515,6 +5579,9 @@ class Router:
                 messages=messages,
                 input=input,
             )
+        else:
+            deployment = None
+
         if deployment is None:
             verbose_router_logger.info(
                 f"get_available_deployment for model: {model}, No deployment available"
@@ -5690,6 +5757,9 @@ class Router:
     def _initialize_alerting(self):
         from litellm.integrations.SlackAlerting.slack_alerting import SlackAlerting

+        if self.alerting_config is None:
+            return
+
         router_alerting_config: AlertingConfig = self.alerting_config

         _slack_alerting_logger = SlackAlerting(
@@ -5700,7 +5770,7 @@ class Router:

         self.slack_alerting_logger = _slack_alerting_logger

-        litellm.callbacks.append(_slack_alerting_logger)
+        litellm.callbacks.append(_slack_alerting_logger) # type: ignore
         litellm.success_callback.append(
             _slack_alerting_logger.response_taking_too_long_callback
         )