From 05ba34b9b719bac16b3671bf4602fdb97dbc069e Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 27 Jul 2024 13:13:31 -0700 Subject: [PATCH] fix(utils.py): add exception mapping for databricks errors --- litellm/tests/test_exceptions.py | 4 +++- litellm/utils.py | 35 +++++++++++++++++--------------- 2 files changed, 22 insertions(+), 17 deletions(-) diff --git a/litellm/tests/test_exceptions.py b/litellm/tests/test_exceptions.py index 66c8594bb..dfefe99d6 100644 --- a/litellm/tests/test_exceptions.py +++ b/litellm/tests/test_exceptions.py @@ -770,7 +770,9 @@ def test_litellm_predibase_exception(): # print(f"accuracy_score: {accuracy_score}") -@pytest.mark.parametrize("provider", ["predibase", "vertex_ai_beta", "anthropic"]) +@pytest.mark.parametrize( + "provider", ["predibase", "vertex_ai_beta", "anthropic", "databricks"] +) def test_exception_mapping(provider): """ For predibase, run through a set of mock exceptions diff --git a/litellm/utils.py b/litellm/utils.py index 358904677..780148059 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -6723,7 +6723,10 @@ def exception_type( model=model, response=original_exception.response, ) - elif custom_llm_provider == "predibase": + elif ( + custom_llm_provider == "predibase" + or custom_llm_provider == "databricks" + ): if "authorization denied for" in error_str: exception_mapping_worked = True @@ -6739,8 +6742,8 @@ def exception_type( error_str += "XXXXXXX" + '"' raise AuthenticationError( - message=f"PredibaseException: Authentication Error - {error_str}", - llm_provider="predibase", + message=f"{custom_llm_provider}Exception: Authentication Error - {error_str}", + llm_provider=custom_llm_provider, model=model, response=original_exception.response, litellm_debug_info=extra_information, @@ -6749,35 +6752,35 @@ def exception_type( if original_exception.status_code == 500: exception_mapping_worked = True raise litellm.InternalServerError( - message=f"PredibaseException - {original_exception.message}", - llm_provider="predibase", + message=f"{custom_llm_provider}Exception - {original_exception.message}", + llm_provider=custom_llm_provider, model=model, ) elif original_exception.status_code == 401: exception_mapping_worked = True raise AuthenticationError( - message=f"PredibaseException - {original_exception.message}", - llm_provider="predibase", + message=f"{custom_llm_provider}Exception - {original_exception.message}", + llm_provider=custom_llm_provider, model=model, ) elif original_exception.status_code == 400: exception_mapping_worked = True raise BadRequestError( - message=f"PredibaseException - {original_exception.message}", - llm_provider="predibase", + message=f"{custom_llm_provider}Exception - {original_exception.message}", + llm_provider=custom_llm_provider, model=model, ) elif original_exception.status_code == 404: exception_mapping_worked = True raise NotFoundError( - message=f"PredibaseException - {original_exception.message}", - llm_provider="predibase", + message=f"{custom_llm_provider}Exception - {original_exception.message}", + llm_provider=custom_llm_provider, model=model, ) elif original_exception.status_code == 408: exception_mapping_worked = True raise Timeout( - message=f"PredibaseException - {original_exception.message}", + message=f"{custom_llm_provider}Exception - {original_exception.message}", model=model, llm_provider=custom_llm_provider, litellm_debug_info=extra_information, @@ -6788,7 +6791,7 @@ def exception_type( ): exception_mapping_worked = True raise BadRequestError( - message=f"PredibaseException - {original_exception.message}", + message=f"{custom_llm_provider}Exception - {original_exception.message}", model=model, llm_provider=custom_llm_provider, litellm_debug_info=extra_information, @@ -6796,7 +6799,7 @@ def exception_type( elif original_exception.status_code == 429: exception_mapping_worked = True raise RateLimitError( - message=f"PredibaseException - {original_exception.message}", + message=f"{custom_llm_provider}Exception - {original_exception.message}", model=model, llm_provider=custom_llm_provider, litellm_debug_info=extra_information, @@ -6804,7 +6807,7 @@ def exception_type( elif original_exception.status_code == 503: exception_mapping_worked = True raise ServiceUnavailableError( - message=f"PredibaseException - {original_exception.message}", + message=f"{custom_llm_provider}Exception - {original_exception.message}", model=model, llm_provider=custom_llm_provider, litellm_debug_info=extra_information, @@ -6812,7 +6815,7 @@ def exception_type( elif original_exception.status_code == 504: # gateway timeout error exception_mapping_worked = True raise Timeout( - message=f"PredibaseException - {original_exception.message}", + message=f"{custom_llm_provider}Exception - {original_exception.message}", model=model, llm_provider=custom_llm_provider, litellm_debug_info=extra_information,