LiteLLM Minor Fixes and Improvements (09/13/2024) (#5689)

* refactor: clean up unused variables + fix pyright errors

* feat(health_check.py): Closes https://github.com/BerriAI/litellm/issues/5686

* fix(o1_reasoning.py): add stricter check for o1 reasoning model

* refactor(mistral/): make it easier to see mistral transformation logic

* fix(openai.py): fix openai o1 model param mapping

Fixes https://github.com/BerriAI/litellm/issues/5685

* feat(main.py): infer finetuned gemini model from base model

Fixes https://github.com/BerriAI/litellm/issues/5678

* docs(vertex.md): update docs to call finetuned gemini models

* feat(proxy_server.py): allow admin to hide proxy model aliases

Closes https://github.com/BerriAI/litellm/issues/5692

* docs(load_balancing.md): add docs on hiding alias models from proxy config (see the alias config sketch after this list)

* fix(base.py): don't raise NotImplementedError

* fix(user_api_key_auth.py): fix model max budget check

* fix(router.py): fix elif

* fix(user_api_key_auth.py): don't set team_id to empty str

* fix(team_endpoints.py): fix response type

* test(test_completion.py): handle predibase error

* test(test_proxy_server.py): fix test

* fix(o1_transformation.py): fix max_completion_tokens mapping (see the param-mapping sketch after this list)

* test(test_image_generation.py): mark flaky test
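
The param-mapping sketch referenced above is a hedged usage example, not code from this commit (it assumes a configured `OPENAI_API_KEY`): o1-family models accept `max_completion_tokens` rather than `max_tokens`, so the intent of the fix is that callers keep passing `max_tokens` and litellm's o1 transformation translates it upstream.

```python
import litellm

response = litellm.completion(
    model="o1-preview",
    messages=[{"role": "user", "content": "Explain the router fallback change in one sentence."}],
    max_tokens=256,  # expected to be sent upstream as max_completion_tokens for o1 models
)
print(response.choices[0].message.content)
```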
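
The alias config sketch referenced above is grounded in the router.py diff below (model names and the API key are placeholders): `model_group_alias` values may now be either a plain string or a dict with `model` and `hidden` keys, and hidden aliases stay routable but are left out of the user-facing model list.

```python
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            # placeholder deployment; the api_key value is illustrative only
            "litellm_params": {"model": "gpt-3.5-turbo", "api_key": "sk-placeholder"},
        }
    ],
    model_group_alias={
        # plain string alias: requests for "gpt-4" are routed to the "gpt-3.5-turbo" group
        "gpt-4": "gpt-3.5-turbo",
        # dict alias: still routable, but hidden from the user-facing model list
        "internal-gpt": {"model": "gpt-3.5-turbo", "hidden": True},
    },
)

print(router.get_model_names())  # the hidden alias should not appear here
```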

Author: Krish Dholakia
Date: 2024-09-14 10:02:55 -07:00 (committed by GitHub)
Parent: 60c5d3ebec
Commit: 713d762411
35 changed files with 1020 additions and 539 deletions


@@ -92,6 +92,7 @@ from litellm.types.router import (
RetryPolicy,
RouterErrors,
RouterGeneralSettings,
RouterModelGroupAliasItem,
RouterRateLimitError,
RouterRateLimitErrorBasic,
updateDeployment,
@@ -105,6 +106,7 @@ from litellm.utils import (
calculate_max_parallel_requests,
create_proxy_transport_and_mounts,
get_llm_provider,
get_secret,
get_utc_datetime,
)
@@ -156,7 +158,9 @@ class Router:
fallbacks: List = [],
context_window_fallbacks: List = [],
content_policy_fallbacks: List = [],
model_group_alias: Optional[dict] = {},
model_group_alias: Optional[
Dict[str, Union[str, RouterModelGroupAliasItem]]
] = {},
enable_pre_call_checks: bool = False,
enable_tag_filtering: bool = False,
retry_after: int = 0, # min time to wait before retrying a failed request
@@ -331,7 +335,8 @@ class Router:
self.set_model_list(model_list)
self.healthy_deployments: List = self.model_list # type: ignore
for m in model_list:
self.deployment_latency_map[m["litellm_params"]["model"]] = 0
if "model" in m["litellm_params"]:
self.deployment_latency_map[m["litellm_params"]["model"]] = 0
else:
self.model_list: List = (
[]
@@ -398,7 +403,7 @@ class Router:
self.previous_models: List = (
[]
) # list to store failed calls (passed in as metadata to next call)
self.model_group_alias: dict = (
self.model_group_alias: Dict[str, Union[str, RouterModelGroupAliasItem]] = (
model_group_alias or {}
) # dict to store aliases for router, ex. {"gpt-4": "gpt-3.5-turbo"}, all requests with gpt-4 -> get routed to gpt-3.5-turbo group
@@ -1179,6 +1184,7 @@ class Router:
raise e
def _image_generation(self, prompt: str, model: str, **kwargs):
model_name = ""
try:
verbose_router_logger.debug(
f"Inside _image_generation()- model: {model}; kwargs: {kwargs}"
@@ -1269,6 +1275,7 @@ class Router:
raise e
async def _aimage_generation(self, prompt: str, model: str, **kwargs):
model_name = ""
try:
verbose_router_logger.debug(
f"Inside _image_generation()- model: {model}; kwargs: {kwargs}"
@@ -1401,6 +1408,7 @@ class Router:
raise e
async def _atranscription(self, file: FileTypes, model: str, **kwargs):
model_name = model
try:
verbose_router_logger.debug(
f"Inside _atranscription()- model: {model}; kwargs: {kwargs}"
@@ -1781,6 +1789,7 @@ class Router:
is_async: Optional[bool] = False,
**kwargs,
):
messages = [{"role": "user", "content": prompt}]
try:
kwargs["model"] = model
kwargs["prompt"] = prompt
@@ -1789,7 +1798,6 @@ class Router:
timeout = kwargs.get("request_timeout", self.timeout)
kwargs.setdefault("metadata", {}).update({"model_group": model})
messages = [{"role": "user", "content": prompt}]
# pick the one that is available (lowest TPM/RPM)
deployment = self.get_available_deployment(
model=model,
@@ -2534,13 +2542,13 @@ class Router:
try:
# Update kwargs with the current model name or any other model-specific adjustments
## SET CUSTOM PROVIDER TO SELECTED DEPLOYMENT ##
_, custom_llm_provider, _, _ = get_llm_provider(
_, custom_llm_provider, _, _ = get_llm_provider( # type: ignore
model=model_name["litellm_params"]["model"]
)
new_kwargs = copy.deepcopy(kwargs)
new_kwargs.pop("custom_llm_provider", None)
return await litellm.aretrieve_batch(
custom_llm_provider=custom_llm_provider, **new_kwargs
custom_llm_provider=custom_llm_provider, **new_kwargs # type: ignore
)
except Exception as e:
receieved_exceptions.append(e)
@@ -2616,13 +2624,13 @@ class Router:
for result in results:
if result is not None:
## check batch id
if final_results["first_id"] is None:
final_results["first_id"] = result.first_id
final_results["last_id"] = result.last_id
if final_results["first_id"] is None and hasattr(result, "first_id"):
final_results["first_id"] = getattr(result, "first_id")
final_results["last_id"] = getattr(result, "last_id")
final_results["data"].extend(result.data) # type: ignore
## check 'has_more'
if result.has_more is True:
if getattr(result, "has_more", False) is True:
final_results["has_more"] = True
return final_results
@@ -2874,8 +2882,12 @@ class Router:
verbose_router_logger.debug(f"Traceback{traceback.format_exc()}")
original_exception = e
fallback_model_group = None
original_model_group = kwargs.get("model")
original_model_group: Optional[str] = kwargs.get("model") # type: ignore
fallback_failure_exception_str = ""
if original_model_group is None:
raise e
try:
verbose_router_logger.debug("Trying to fallback b/w models")
if isinstance(e, litellm.ContextWindowExceededError):
@@ -2972,7 +2984,7 @@ class Router:
f"No fallback model group found for original model_group={model_group}. Fallbacks={fallbacks}"
)
if hasattr(original_exception, "message"):
original_exception.message += f"No fallback model group found for original model_group={model_group}. Fallbacks={fallbacks}"
original_exception.message += f"No fallback model group found for original model_group={model_group}. Fallbacks={fallbacks}" # type: ignore
raise original_exception
response = await run_async_fallback(
@@ -2996,12 +3008,12 @@ class Router:
if hasattr(original_exception, "message"):
# add the available fallbacks to the exception
original_exception.message += "\nReceived Model Group={}\nAvailable Model Group Fallbacks={}".format(
original_exception.message += "\nReceived Model Group={}\nAvailable Model Group Fallbacks={}".format( # type: ignore
model_group,
fallback_model_group,
)
if len(fallback_failure_exception_str) > 0:
original_exception.message += (
original_exception.message += ( # type: ignore
"\nError doing the fallback: {}".format(
fallback_failure_exception_str
)
@@ -3117,9 +3129,15 @@ class Router:
## LOGGING
kwargs = self.log_retry(kwargs=kwargs, e=e)
remaining_retries = num_retries - current_attempt
_healthy_deployments, _ = await self._async_get_healthy_deployments(
model=kwargs.get("model"),
)
_model: Optional[str] = kwargs.get("model") # type: ignore
if _model is not None:
_healthy_deployments, _ = (
await self._async_get_healthy_deployments(
model=_model,
)
)
else:
_healthy_deployments = []
_timeout = self._time_to_sleep_before_retry(
e=original_exception,
remaining_retries=remaining_retries,
@@ -3129,8 +3147,8 @@ class Router:
await asyncio.sleep(_timeout)
if type(original_exception) in litellm.LITELLM_EXCEPTION_TYPES:
original_exception.max_retries = num_retries
original_exception.num_retries = current_attempt
setattr(original_exception, "max_retries", num_retries)
setattr(original_exception, "num_retries", current_attempt)
raise original_exception
@@ -3225,8 +3243,12 @@ class Router:
return response
except Exception as e:
original_exception = e
original_model_group = kwargs.get("model")
original_model_group: Optional[str] = kwargs.get("model")
verbose_router_logger.debug(f"An exception occurs {original_exception}")
if original_model_group is None:
raise e
try:
verbose_router_logger.debug(
f"Trying to fallback b/w models. Initial model group: {model_group}"
@@ -3336,10 +3358,10 @@ class Router:
return 0
response_headers: Optional[httpx.Headers] = None
if hasattr(e, "response") and hasattr(e.response, "headers"):
response_headers = e.response.headers
if hasattr(e, "response") and hasattr(e.response, "headers"): # type: ignore
response_headers = e.response.headers # type: ignore
elif hasattr(e, "litellm_response_headers"):
response_headers = e.litellm_response_headers
response_headers = e.litellm_response_headers # type: ignore
if response_headers is not None:
timeout = litellm._calculate_retry_after(
@@ -3398,9 +3420,13 @@ class Router:
except Exception as e:
current_attempt = None
original_exception = e
_model: Optional[str] = kwargs.get("model") # type: ignore
if _model is None:
raise e # re-raise error, if model can't be determined for loadbalancing
### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR
_healthy_deployments, _all_deployments = self._get_healthy_deployments(
model=kwargs.get("model"),
model=_model,
)
# raises an exception if this error should not be retries
@@ -3438,8 +3464,12 @@ class Router:
except Exception as e:
## LOGGING
kwargs = self.log_retry(kwargs=kwargs, e=e)
_model: Optional[str] = kwargs.get("model") # type: ignore
if _model is None:
raise e # re-raise error, if model can't be determined for loadbalancing
_healthy_deployments, _ = self._get_healthy_deployments(
model=kwargs.get("model"),
model=_model,
)
remaining_retries = num_retries - current_attempt
_timeout = self._time_to_sleep_before_retry(
@@ -4055,7 +4085,7 @@ class Router:
if isinstance(_litellm_params, dict):
for k, v in _litellm_params.items():
if isinstance(v, str) and v.startswith("os.environ/"):
_litellm_params[k] = litellm.get_secret(v)
_litellm_params[k] = get_secret(v)
_model_info: dict = model.pop("model_info", {})
@@ -4392,7 +4422,6 @@ class Router:
- ModelGroupInfo if able to construct a model group
- None if error constructing model group info
"""
model_group_info: Optional[ModelGroupInfo] = None
total_tpm: Optional[int] = None
@@ -4557,12 +4586,23 @@ class Router:
Returns:
- ModelGroupInfo if able to construct a model group
- None if error constructing model group info
- None if error constructing model group info or hidden model group
"""
## Check if model group alias
if model_group in self.model_group_alias:
item = self.model_group_alias[model_group]
if isinstance(item, str):
_router_model_group = item
elif isinstance(item, dict):
if item["hidden"] is True:
return None
else:
_router_model_group = item["model"]
else:
return None
return self._set_model_group_info(
model_group=self.model_group_alias[model_group],
model_group=_router_model_group,
user_facing_model_group_name=model_group,
)
@@ -4666,7 +4706,14 @@ class Router:
Includes model_group_alias models too.
"""
return self.model_names + list(self.model_group_alias.keys())
model_list = self.get_model_list()
if model_list is None:
return []
model_names = []
for m in model_list:
model_names.append(m["model_name"])
return model_names
def get_model_list(
self, model_name: Optional[str] = None
@@ -4678,9 +4725,21 @@ class Router:
returned_models: List[DeploymentTypedDict] = []
for model_alias, model_value in self.model_group_alias.items():
if isinstance(model_value, str):
_router_model_name: str = model_value
elif isinstance(model_value, dict):
_model_value = RouterModelGroupAliasItem(**model_value) # type: ignore
if _model_value["hidden"] is True:
continue
else:
_router_model_name = _model_value["model"]
else:
continue
returned_models.extend(
self._get_all_deployments(
model_name=model_value, model_alias=model_alias
model_name=_router_model_name, model_alias=model_alias
)
)
@@ -5078,10 +5137,11 @@ class Router:
)
if model in self.model_group_alias:
verbose_router_logger.debug(
f"Using a model alias. Got Request for {model}, sending requests to {self.model_group_alias.get(model)}"
)
model = self.model_group_alias[model]
_item = self.model_group_alias[model]
if isinstance(_item, str):
model = _item
else:
model = _item["model"]
if model not in self.model_names:
# check if provider/ specific wildcard routing
@@ -5124,7 +5184,9 @@ class Router:
m for m in self.model_list if m["litellm_params"]["model"] == model
]
litellm.print_verbose(f"initial list of deployments: {healthy_deployments}")
verbose_router_logger.debug(
f"initial list of deployments: {healthy_deployments}"
)
if len(healthy_deployments) == 0:
raise ValueError(
@@ -5208,7 +5270,7 @@ class Router:
)
# check if user wants to do tag based routing
healthy_deployments = await get_deployments_for_tag(
healthy_deployments = await get_deployments_for_tag( # type: ignore
llm_router_instance=self,
request_kwargs=request_kwargs,
healthy_deployments=healthy_deployments,
@@ -5241,7 +5303,7 @@ class Router:
input=input,
)
)
if (
elif (
self.routing_strategy == "cost-based-routing"
and self.lowestcost_logger is not None
):
@@ -5326,6 +5388,8 @@ class Router:
############## No RPM/TPM passed, we do a random pick #################
item = random.choice(healthy_deployments)
return item or item[0]
else:
deployment = None
if deployment is None:
verbose_router_logger.info(
f"get_available_deployment for model: {model}, No deployment available"
@@ -5515,6 +5579,9 @@ class Router:
messages=messages,
input=input,
)
else:
deployment = None
if deployment is None:
verbose_router_logger.info(
f"get_available_deployment for model: {model}, No deployment available"
@@ -5690,6 +5757,9 @@ class Router:
def _initialize_alerting(self):
from litellm.integrations.SlackAlerting.slack_alerting import SlackAlerting
if self.alerting_config is None:
return
router_alerting_config: AlertingConfig = self.alerting_config
_slack_alerting_logger = SlackAlerting(
@@ -5700,7 +5770,7 @@ class Router:
self.slack_alerting_logger = _slack_alerting_logger
litellm.callbacks.append(_slack_alerting_logger)
litellm.callbacks.append(_slack_alerting_logger) # type: ignore
litellm.success_callback.append(
_slack_alerting_logger.response_taking_too_long_callback
)