Merge pull request #3828 from BerriAI/litellm_outage_alerting

fix(slack_alerting.py): support region based outage alerting
Krish Dholakia 2024-05-24 19:13:17 -07:00 committed by GitHub
commit d25ed9c4d3
9 changed files with 414 additions and 78 deletions

View file

@@ -1,6 +1,6 @@
#### What this does ####
# Class for sending Slack Alerts #
import dotenv, os, traceback
from litellm.proxy._types import UserAPIKeyAuth, CallInfo
from litellm._logging import verbose_logger, verbose_proxy_logger
import litellm, threading
@@ -15,6 +15,34 @@ from enum import Enum
from datetime import datetime as dt, timedelta, timezone
from litellm.integrations.custom_logger import CustomLogger
import random
from typing import TypedDict
from openai import APIError
import litellm.types
import litellm.types.router
class OutageModel(TypedDict):
model_id: str
alerts: List[int]
deployment_ids: List[str]
minor_alert_sent: bool
major_alert_sent: bool
last_updated_at: float
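For orientation, this is roughly what a cached OutageModel entry looks like once a deployment has started failing. A sketch only: the deployment id and values below are illustrative, not taken from the code.

# Illustrative cache entry, keyed by deployment id (values are made up):
outage_value: OutageModel = {
    "model_id": "azure-gpt-35-deployment-1",         # hypothetical deployment id
    "alerts": [408, 500, 500],                       # tracked error status codes
    "deployment_ids": ["azure-gpt-35-deployment-1"],
    "minor_alert_sent": True,                        # fired once 3 errors were seen
    "major_alert_sent": False,                       # fires at 10 errors by default
    "last_updated_at": 1716600000.0,                 # time.time() of the last failure
}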
AlertType = Literal[
"llm_exceptions",
"llm_too_slow",
"llm_requests_hanging",
"budget_alerts",
"db_exceptions",
"daily_reports",
"spend_reports",
"cooldown_deployment",
"new_model_added",
"outage_alerts",
]
class LiteLLMBase(BaseModel):
@@ -37,6 +65,10 @@ class SlackAlertingArgs(LiteLLMBase):
)
report_check_interval: int = 5 * 60 # 5 minutes
budget_alert_ttl: int = 24 * 60 * 60 # 24 hours
outage_alert_ttl: int = 1 * 60 * 60 # 1 hour
minor_outage_alert_threshold: int = 3
major_outage_alert_threshold: int = 10
max_outage_alert_list_size: int = 10 # prevent memory leak
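These knobs are read from `alerting_args`. A minimal sketch of overriding them when constructing the alerter directly (constructor usage assumed from this class; values are illustrative):

from litellm.integrations.slack_alerting import SlackAlerting

slack_alerting = SlackAlerting(
    alerting=["slack"],
    alert_types=["outage_alerts"],
    alerting_args={
        "outage_alert_ttl": 60 * 60,         # keep per-deployment outage state for 1 hour
        "minor_outage_alert_threshold": 3,   # minor alert after 3 tracked errors
        "major_outage_alert_threshold": 10,  # major alert after 10 tracked errors
        "max_outage_alert_list_size": 10,    # cap stored error codes (prevent memory leak)
    },
)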
class WebhookEvent(CallInfo):
@@ -86,19 +118,7 @@ class SlackAlerting(CustomLogger):
internal_usage_cache: Optional[DualCache] = None,
alerting_threshold: float = 300, # threshold for slow / hanging llm responses (in seconds)
alerting: Optional[List] = [],
alert_types: List[AlertType] = [
"llm_exceptions",
"llm_too_slow",
"llm_requests_hanging",
@@ -108,6 +128,7 @@ class SlackAlerting(CustomLogger):
"spend_reports",
"cooldown_deployment",
"new_model_added",
"outage_alerts",
],
alert_to_webhook_url: Optional[
Dict
@@ -124,6 +145,7 @@ class SlackAlerting(CustomLogger):
self.is_running = False
self.alerting_args = SlackAlertingArgs(**alerting_args)
self.default_webhook_url = default_webhook_url
self.llm_router: Optional[litellm.Router] = None
def update_values(
self,
@@ -132,6 +154,7 @@ class SlackAlerting(CustomLogger):
alert_types: Optional[List] = None,
alert_to_webhook_url: Optional[Dict] = None,
alerting_args: Optional[Dict] = None,
llm_router: Optional[litellm.Router] = None,
):
if alerting is not None:
self.alerting = alerting
@@ -147,6 +170,8 @@ class SlackAlerting(CustomLogger):
self.alert_to_webhook_url = alert_to_webhook_url
else:
self.alert_to_webhook_url.update(alert_to_webhook_url)
if llm_router is not None:
self.llm_router = llm_router
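The router handle is what later lets `outage_alerts` resolve the model, provider and api_base for a failing deployment id. A rough wiring sketch, reusing the `slack_alerting` instance sketched earlier (model list values are illustrative):

import os
import litellm

router = litellm.Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "gpt-3.5-turbo",
                "api_key": os.getenv("OPENAI_API_KEY"),
            },
        }
    ]
)
slack_alerting.update_values(llm_router=router)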
async def deployment_in_cooldown(self):
pass
@@ -696,6 +721,158 @@ class SlackAlerting(CustomLogger):
return
return
def _count_outage_alerts(self, alerts: List[int]) -> str:
"""
Parameters:
- alerts: List[int] -> list of error codes (either 408 or 500+)
Returns:
- str -> formatted string. This is an alert message, giving a human-friendly description of the errors.
"""
error_breakdown = {"Timeout Errors": 0, "API Errors": 0, "Unknown Errors": 0}
for alert in alerts:
if alert == 408:
error_breakdown["Timeout Errors"] += 1
elif alert >= 500:
error_breakdown["API Errors"] += 1
else:
error_breakdown["Unknown Errors"] += 1
error_msg = ""
for key, value in error_breakdown.items():
if value > 0:
error_msg += "\n{}: {}\n".format(key, value)
return error_msg
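A quick illustration of the breakdown this helper produces, given a SlackAlerting instance (the buckets follow the mapping above: 408 is a timeout, >=500 is an API error, anything else is unknown):

breakdown = slack_alerting._count_outage_alerts(alerts=[408, 500, 503, 429])
print(breakdown)
# Timeout Errors: 1
#
# API Errors: 2
#
# Unknown Errors: 1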
async def outage_alerts(
self,
exception: APIError,
deployment_id: str,
) -> None:
"""
Send slack alert if model is badly configured / having an outage (408, 401, 429, >=500).
key = model_id
value = {
- model_id
- threshold
- alerts []
}
ttl = 1hr
max_alerts_size = 10
"""
try:
outage_value: Optional[OutageModel] = await self.internal_usage_cache.async_get_cache(key=deployment_id) # type: ignore
if (
getattr(exception, "status_code", None) is None
or (
exception.status_code != 408 # type: ignore
and exception.status_code < 500 # type: ignore
)
or self.llm_router is None
):
return
### EXTRACT MODEL DETAILS ###
deployment = self.llm_router.get_deployment(model_id=deployment_id)
if deployment is None:
return
model = deployment.litellm_params.model
provider = deployment.litellm_params.custom_llm_provider
if provider is None:
try:
model, provider, _, _ = litellm.get_llm_provider(model=model)
except Exception as e:
provider = ""
api_base = litellm.get_api_base(
model=model, optional_params=deployment.litellm_params
)
if outage_value is None:
outage_value = OutageModel(
model_id=deployment_id,
alerts=[exception.status_code], # type: ignore
deployment_ids=[deployment_id],
minor_alert_sent=False,
major_alert_sent=False,
last_updated_at=time.time(),
)
## add to cache ##
await self.internal_usage_cache.async_set_cache(
key=deployment_id,
value=outage_value,
ttl=self.alerting_args.outage_alert_ttl,
)
return
outage_value["alerts"].append(exception.status_code) # type: ignore
outage_value["deployment_ids"].append(deployment_id)
outage_value["last_updated_at"] = time.time()
## MINOR OUTAGE ALERT SENT ##
if (
outage_value["minor_alert_sent"] == False
and len(outage_value["alerts"])
>= self.alerting_args.minor_outage_alert_threshold
):
msg = f"""\n\n
* Minor Service Outage*
*Model Name:* `{model}`
*Provider:* `{provider}`
*API Base:* `{api_base}`
*Errors:*
{self._count_outage_alerts(alerts=outage_value["alerts"])}
*Last Check:* `{round(time.time() - outage_value["last_updated_at"], 4)}s ago`\n\n
"""
# send minor alert
_result_val = self.send_alert(
message=msg, level="Medium", alert_type="outage_alerts"
)
if _result_val is not None:
await _result_val
# set to true
outage_value["minor_alert_sent"] = True
elif (
outage_value["major_alert_sent"] == False
and len(outage_value["alerts"])
>= self.alerting_args.major_outage_alert_threshold
):
msg = f"""\n\n
* Major Service Outage*
*Model Name:* `{model}`
*Provider:* `{provider}`
*API Base:* `{api_base}`
*Errors:*
{self._count_outage_alerts(alerts=outage_value["alerts"])}
*Last Check:* `{round(time.time() - outage_value["last_updated_at"], 4)}s ago`\n\n
"""
# send major alert
await self.send_alert(
message=msg, level="High", alert_type="outage_alerts"
)
# set to true
outage_value["major_alert_sent"] = True
## update cache ##
await self.internal_usage_cache.async_set_cache(
key=deployment_id, value=outage_value
)
except Exception as e:
pass
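To make the escalation concrete, a sketch that drives this method directly with a 503 error (error construction copied from the test further down; it assumes `update_values(llm_router=router)` was called and that `deployment_id` exists in that router):

import httpx
import litellm

err = litellm.ServiceUnavailableError(
    message="API is unavailable",
    model="gpt-3.5-turbo",
    llm_provider="openai",
    response=httpx.Response(
        status_code=503,
        request=httpx.Request(method="GET", url="https://github.com/BerriAI/litellm"),
    ),
)

async def simulate_outage(slack_alerting, deployment_id: str):
    # With default thresholds: no alert on errors 1-2, a "Minor Service Outage"
    # (level "Medium") on the 3rd, and a "Major Service Outage" (level "High") on the 10th.
    for _ in range(10):
        await slack_alerting.outage_alerts(exception=err, deployment_id=deployment_id)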
async def model_added_alert(
self, model_name: str, litellm_model_name: str, passed_model_info: Any
):
@@ -745,10 +922,12 @@ Model Info:
```
"""
alert_val = self.send_alert(
message=message, level="Low", alert_type="new_model_added"
)
pass
if alert_val is not None and asyncio.iscoroutine(alert_val):
await alert_val
async def model_removed_alert(self, model_name: str):
pass
@@ -846,6 +1025,7 @@ Model Info:
"spend_reports",
"new_model_added",
"cooldown_deployment",
"outage_alerts",
],
user_info: Optional[WebhookEvent] = None,
**kwargs,
@@ -969,10 +1149,12 @@ Model Info:
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
"""Log failure + deployment latency"""
_litellm_params = kwargs.get("litellm_params", {})
_model_info = _litellm_params.get("model_info", {}) or {}
model_id = _model_info.get("id", "")
try:
if "daily_reports" in self.alert_types: if "daily_reports" in self.alert_types:
model_id = ( try:
kwargs.get("litellm_params", {}).get("model_info", {}).get("id", "")
)
await self.async_update_daily_reports( await self.async_update_daily_reports(
DeploymentMetrics( DeploymentMetrics(
id=model_id, id=model_id,
@ -981,6 +1163,39 @@ Model Info:
updated_at=litellm.utils.get_utc_datetime(), updated_at=litellm.utils.get_utc_datetime(),
) )
) )
except Exception as e:
verbose_logger.debug(f"Exception raises -{str(e)}")
if "outage_alerts" in self.alert_types and isinstance(
kwargs.get("exception", ""), APIError
):
_litellm_params = litellm.types.router.LiteLLM_Params(
model=kwargs.get("model", ""),
**kwargs.get("litellm_params", {}),
**kwargs.get("optional_params", {}),
)
_region_name = litellm.utils._get_model_region(
custom_llm_provider=kwargs.get("custom_llm_provider", ""),
litellm_params=_litellm_params,
)
# if region name not known, default to api base #
if _region_name is None:
_region_name = litellm.get_api_base(
model=kwargs.get("model", ""),
optional_params={
**kwargs.get("litellm_params", {}),
**kwargs.get("optional_params", {}),
},
)
if _region_name is None:
_region_name = ""
await self.outage_alerts(
exception=kwargs["exception"],
deployment_id=model_id,
)
except Exception as e:
pass
async def _run_scheduler_helper(self, llm_router) -> bool:
"""

View file

@@ -126,7 +126,7 @@ def convert_to_ollama_image(openai_image_url: str):
else:
base64_data = openai_image_url
return base64_data
except Exception as e:
if "Error: Unable to fetch image from URL" in str(e):
raise e
@@ -134,6 +134,7 @@ def convert_to_ollama_image(openai_image_url: str):
"""Image url not in expected format. Example Expected input - "image_url": "data:image/jpeg;base64,{base64_image}". """
)
def ollama_pt( def ollama_pt(
model, messages
): # https://github.com/ollama/ollama/blob/af4cf55884ac54b9e637cd71dadfe9b7a5685877/docs/modelfile.md#template
@@ -166,7 +167,9 @@ def ollama_pt(
if element["type"] == "text":
prompt += element["text"]
elif element["type"] == "image_url":
base64_image = convert_to_ollama_image(
element["image_url"]["url"]
)
images.append(base64_image)
return {"prompt": prompt, "images": images}
else:
@@ -1533,6 +1536,7 @@ def _gemini_vision_convert_messages(messages: list):
# Case 2: Base64 image data
import base64
import io
# Extract the base64 image data
base64_data = img.split("base64,")[1]

View file

@@ -420,6 +420,8 @@ def mock_completion(
api_key="mock-key",
)
if isinstance(mock_response, Exception):
if isinstance(mock_response, openai.APIError):
raise mock_response
raise litellm.APIError(
status_code=500, # type: ignore
message=str(mock_response),
@@ -463,7 +465,9 @@ def mock_completion(
return model_response
except Exception as e:
if isinstance(e, openai.APIError):
raise e
traceback.print_exc()
raise Exception("Mock completion response failed")
@@ -864,6 +868,7 @@ def completion(
user=user,
optional_params=optional_params,
litellm_params=litellm_params,
custom_llm_provider=custom_llm_provider,
)
if mock_response:
return mock_completion(

View file

@@ -24,7 +24,7 @@ litellm_settings:
general_settings:
alerting: ["slack"]
# alerting_args:
# report_check_interval: 10
# enable_jwt_auth: True

View file

@@ -3026,7 +3026,7 @@ class ProxyConfig:
general_settings["alert_types"] = _general_settings["alert_types"]
proxy_logging_obj.alert_types = general_settings["alert_types"]
proxy_logging_obj.slack_alerting_instance.update_values(
alert_types=general_settings["alert_types"], llm_router=llm_router
)
if "alert_to_webhook_url" in _general_settings:
@@ -3034,7 +3034,8 @@ class ProxyConfig:
"alert_to_webhook_url"
]
proxy_logging_obj.slack_alerting_instance.update_values(
alert_to_webhook_url=general_settings["alert_to_webhook_url"],
llm_router=llm_router,
)
async def _update_general_settings(self, db_general_settings: Optional[Json]):
@@ -3602,6 +3603,9 @@ async def startup_event():
## Error Tracking ##
error_tracking()
## UPDATE SLACK ALERTING ##
proxy_logging_obj.slack_alerting_instance.update_values(llm_router=llm_router)
db_writer_client = HTTPHandler()
proxy_logging_obj._init_litellm_callbacks() # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made

View file

@@ -42,7 +42,7 @@ import smtplib, re
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
from datetime import datetime, timedelta
from litellm.integrations.slack_alerting import SlackAlerting, AlertType
from typing_extensions import overload
@@ -78,19 +78,7 @@ class ProxyLogging:
self.cache_control_check = _PROXY_CacheControlCheck()
self.alerting: Optional[List] = None
self.alerting_threshold: float = 300 # default to 5 min. threshold
self.alert_types: List[AlertType] = [
"llm_exceptions",
"llm_too_slow",
"llm_requests_hanging",
@@ -100,6 +88,7 @@ class ProxyLogging:
"spend_reports",
"cooldown_deployment",
"new_model_added",
"outage_alerts",
]
self.slack_alerting_instance = SlackAlerting(
alerting_threshold=self.alerting_threshold,
@@ -113,21 +102,7 @@ class ProxyLogging:
alerting: Optional[List],
alerting_threshold: Optional[float],
redis_cache: Optional[RedisCache],
alert_types: Optional[List[AlertType]] = None,
alerting_args: Optional[dict] = None,
):
self.alerting = alerting

View file

@@ -3876,13 +3876,13 @@ class Router:
_api_base = litellm.get_api_base(
model=_model_name, optional_params=temp_litellm_params
)
# asyncio.create_task(
# proxy_logging_obj.slack_alerting_instance.send_alert(
# message=f"Router: Cooling down Deployment:\nModel Name: `{_model_name}`\nAPI Base: `{_api_base}`\nCooldown Time: `{cooldown_time} seconds`\nException Status Code: `{str(exception_status)}`\n\nChange 'cooldown_time' + 'allowed_fails' under 'Router Settings' on proxy UI, or via config - https://docs.litellm.ai/docs/proxy/reliability#fallbacks--retries--timeouts--cooldowns",
# alert_type="cooldown_deployment",
# level="Low",
# )
# )
except Exception as e:
pass

View file

@@ -1,10 +1,11 @@
# What is this?
## Tests slack alerting on proxy logging object
import sys, json, uuid, random, httpx
import os
import io, asyncio
from datetime import datetime, timedelta
from typing import Optional
# import logging
# logging.basicConfig(level=logging.DEBUG)
@@ -23,6 +24,7 @@ from unittest.mock import AsyncMock
import pytest
from litellm.router import AlertingConfig, Router
from litellm.proxy._types import CallInfo
from openai import APIError
@pytest.mark.parametrize(
@@ -495,3 +497,109 @@ async def test_webhook_alerting(alerting_type):
user_info=user_info,
)
mock_send_alert.assert_awaited_once()
@pytest.mark.parametrize(
"model, api_base, llm_provider, vertex_project, vertex_location",
[
("gpt-3.5-turbo", None, "openai", None, None),
(
"azure/gpt-3.5-turbo",
"https://openai-gpt-4-test-v-1.openai.azure.com",
"azure",
None,
None,
),
("gemini-pro", None, "vertex_ai", "hardy-device-38811", "us-central1"),
],
)
@pytest.mark.parametrize("error_code", [500, 408, 400])
@pytest.mark.asyncio
async def test_outage_alerting_called(
model, api_base, llm_provider, vertex_project, vertex_location, error_code
):
"""
If call fails, outage alert is called
If multiple calls fail, outage alert is sent
"""
slack_alerting = SlackAlerting(alerting=["webhook"])
litellm.callbacks = [slack_alerting]
error_to_raise: Optional[APIError] = None
if error_code == 400:
print("RAISING 400 ERROR CODE")
error_to_raise = litellm.BadRequestError(
message="this is a bad request",
model=model,
llm_provider=llm_provider,
)
elif error_code == 408:
print("RAISING 408 ERROR CODE")
error_to_raise = litellm.Timeout(
message="A timeout occurred", model=model, llm_provider=llm_provider
)
elif error_code == 500:
print("RAISING 500 ERROR CODE")
error_to_raise = litellm.ServiceUnavailableError(
message="API is unavailable",
model=model,
llm_provider=llm_provider,
response=httpx.Response(
status_code=503,
request=httpx.Request(
method="completion",
url="https://github.com/BerriAI/litellm",
),
),
)
router = Router(
model_list=[
{
"model_name": model,
"litellm_params": {
"model": model,
"api_key": os.getenv("AZURE_API_KEY"),
"api_base": api_base,
"vertex_location": vertex_location,
"vertex_project": vertex_project,
},
}
],
num_retries=0,
allowed_fails=100,
)
slack_alerting.update_values(llm_router=router)
with patch.object(
slack_alerting, "outage_alerts", new=AsyncMock()
) as mock_send_alert:
try:
await router.acompletion(
model=model,
messages=[{"role": "user", "content": "Hey!"}],
mock_response=error_to_raise,
)
except Exception as e:
pass
mock_send_alert.assert_called_once()
with patch.object(slack_alerting, "send_alert", new=AsyncMock()) as mock_send_alert:
for _ in range(3):
try:
await router.acompletion(
model=model,
messages=[{"role": "user", "content": "Hey!"}],
mock_response=error_to_raise,
)
except Exception as e:
pass
await asyncio.sleep(3)
if error_code == 500 or error_code == 408:
mock_send_alert.assert_called_once()
else:
mock_send_alert.assert_not_called()

View file

@@ -6298,7 +6298,9 @@ def get_model_region(
return None
def get_api_base(
model: str, optional_params: Union[dict, LiteLLM_Params]
) -> Optional[str]:
""" """
Returns the api base used for calling the model. Returns the api base used for calling the model.
@ -6318,7 +6320,9 @@ def get_api_base(model: str, optional_params: dict) -> Optional[str]:
""" """
try: try:
if "model" in optional_params: if isinstance(optional_params, LiteLLM_Params):
_optional_params = optional_params
elif "model" in optional_params:
_optional_params = LiteLLM_Params(**optional_params)
else: # prevent needing to copy and pop the dict
_optional_params = LiteLLM_Params(
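In short, get_api_base now accepts either a plain dict of params or an already-built LiteLLM_Params object. A small usage sketch (the api_base value is borrowed from the test above; the dict form relies on the existing fallback branch shown here):

import litellm
from litellm.types.router import LiteLLM_Params

# dict form (existing behaviour)
api_base = litellm.get_api_base(
    model="azure/gpt-3.5-turbo",
    optional_params={"api_base": "https://openai-gpt-4-test-v-1.openai.azure.com"},
)

# typed form (new): pass the LiteLLM_Params object straight through
params = LiteLLM_Params(
    model="azure/gpt-3.5-turbo",
    api_base="https://openai-gpt-4-test-v-1.openai.azure.com",
)
api_base = litellm.get_api_base(model="azure/gpt-3.5-turbo", optional_params=params)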
@@ -6711,6 +6715,8 @@ def get_llm_provider(
Returns the provider for a given model name - e.g. 'azure/chatgpt-v-2' -> 'azure'
For router -> Can also give the whole litellm param dict -> this function will extract the relevant details
Raises Error - if unable to map model to a provider
"""
try:
## IF LITELLM PARAMS GIVEN ##
@@ -8644,7 +8650,16 @@ def exception_type(
)
elif hasattr(original_exception, "status_code"):
exception_mapping_worked = True
if original_exception.status_code == 400:
exception_mapping_worked = True
raise BadRequestError(
message=f"{exception_provider} - {message}",
llm_provider=custom_llm_provider,
model=model,
response=original_exception.response,
litellm_debug_info=extra_information,
)
elif original_exception.status_code == 401:
exception_mapping_worked = True
raise AuthenticationError(
message=f"{exception_provider} - {message}",
@@ -9157,6 +9172,7 @@ def exception_type(
),
),
) )
if hasattr(original_exception, "status_code"): if hasattr(original_exception, "status_code"):
if original_exception.status_code == 400: if original_exception.status_code == 400:
exception_mapping_worked = True exception_mapping_worked = True
@ -9837,7 +9853,16 @@ def exception_type(
) )
elif hasattr(original_exception, "status_code"): elif hasattr(original_exception, "status_code"):
exception_mapping_worked = True exception_mapping_worked = True
if original_exception.status_code == 401: if original_exception.status_code == 400:
exception_mapping_worked = True
raise BadRequestError(
message=f"AzureException - {original_exception.message}",
llm_provider="azure",
model=model,
litellm_debug_info=extra_information,
response=original_exception.response,
)
elif original_exception.status_code == 401:
exception_mapping_worked = True
raise AuthenticationError(
message=f"AzureException - {original_exception.message}",
@@ -9854,7 +9879,7 @@ def exception_type(
litellm_debug_info=extra_information,
llm_provider="azure",
)
elif original_exception.status_code == 422:
exception_mapping_worked = True
raise BadRequestError(
message=f"AzureException - {original_exception.message}",