fix(utils.py): add exception mapping for databricks errors

This commit is contained in:
Krrish Dholakia 2024-07-27 13:13:31 -07:00
parent 1a8f45e8da
commit 05ba34b9b7
2 changed files with 22 additions and 17 deletions

View file

@ -770,7 +770,9 @@ def test_litellm_predibase_exception():
# print(f"accuracy_score: {accuracy_score}") # print(f"accuracy_score: {accuracy_score}")
@pytest.mark.parametrize("provider", ["predibase", "vertex_ai_beta", "anthropic"]) @pytest.mark.parametrize(
"provider", ["predibase", "vertex_ai_beta", "anthropic", "databricks"]
)
def test_exception_mapping(provider): def test_exception_mapping(provider):
""" """
For predibase, run through a set of mock exceptions For predibase, run through a set of mock exceptions

View file

@ -6723,7 +6723,10 @@ def exception_type(
model=model, model=model,
response=original_exception.response, response=original_exception.response,
) )
elif custom_llm_provider == "predibase": elif (
custom_llm_provider == "predibase"
or custom_llm_provider == "databricks"
):
if "authorization denied for" in error_str: if "authorization denied for" in error_str:
exception_mapping_worked = True exception_mapping_worked = True
@ -6739,8 +6742,8 @@ def exception_type(
error_str += "XXXXXXX" + '"' error_str += "XXXXXXX" + '"'
raise AuthenticationError( raise AuthenticationError(
message=f"PredibaseException: Authentication Error - {error_str}", message=f"{custom_llm_provider}Exception: Authentication Error - {error_str}",
llm_provider="predibase", llm_provider=custom_llm_provider,
model=model, model=model,
response=original_exception.response, response=original_exception.response,
litellm_debug_info=extra_information, litellm_debug_info=extra_information,
@ -6749,35 +6752,35 @@ def exception_type(
if original_exception.status_code == 500: if original_exception.status_code == 500:
exception_mapping_worked = True exception_mapping_worked = True
raise litellm.InternalServerError( raise litellm.InternalServerError(
message=f"PredibaseException - {original_exception.message}", message=f"{custom_llm_provider}Exception - {original_exception.message}",
llm_provider="predibase", llm_provider=custom_llm_provider,
model=model, model=model,
) )
elif original_exception.status_code == 401: elif original_exception.status_code == 401:
exception_mapping_worked = True exception_mapping_worked = True
raise AuthenticationError( raise AuthenticationError(
message=f"PredibaseException - {original_exception.message}", message=f"{custom_llm_provider}Exception - {original_exception.message}",
llm_provider="predibase", llm_provider=custom_llm_provider,
model=model, model=model,
) )
elif original_exception.status_code == 400: elif original_exception.status_code == 400:
exception_mapping_worked = True exception_mapping_worked = True
raise BadRequestError( raise BadRequestError(
message=f"PredibaseException - {original_exception.message}", message=f"{custom_llm_provider}Exception - {original_exception.message}",
llm_provider="predibase", llm_provider=custom_llm_provider,
model=model, model=model,
) )
elif original_exception.status_code == 404: elif original_exception.status_code == 404:
exception_mapping_worked = True exception_mapping_worked = True
raise NotFoundError( raise NotFoundError(
message=f"PredibaseException - {original_exception.message}", message=f"{custom_llm_provider}Exception - {original_exception.message}",
llm_provider="predibase", llm_provider=custom_llm_provider,
model=model, model=model,
) )
elif original_exception.status_code == 408: elif original_exception.status_code == 408:
exception_mapping_worked = True exception_mapping_worked = True
raise Timeout( raise Timeout(
message=f"PredibaseException - {original_exception.message}", message=f"{custom_llm_provider}Exception - {original_exception.message}",
model=model, model=model,
llm_provider=custom_llm_provider, llm_provider=custom_llm_provider,
litellm_debug_info=extra_information, litellm_debug_info=extra_information,
@ -6788,7 +6791,7 @@ def exception_type(
): ):
exception_mapping_worked = True exception_mapping_worked = True
raise BadRequestError( raise BadRequestError(
message=f"PredibaseException - {original_exception.message}", message=f"{custom_llm_provider}Exception - {original_exception.message}",
model=model, model=model,
llm_provider=custom_llm_provider, llm_provider=custom_llm_provider,
litellm_debug_info=extra_information, litellm_debug_info=extra_information,
@ -6796,7 +6799,7 @@ def exception_type(
elif original_exception.status_code == 429: elif original_exception.status_code == 429:
exception_mapping_worked = True exception_mapping_worked = True
raise RateLimitError( raise RateLimitError(
message=f"PredibaseException - {original_exception.message}", message=f"{custom_llm_provider}Exception - {original_exception.message}",
model=model, model=model,
llm_provider=custom_llm_provider, llm_provider=custom_llm_provider,
litellm_debug_info=extra_information, litellm_debug_info=extra_information,
@ -6804,7 +6807,7 @@ def exception_type(
elif original_exception.status_code == 503: elif original_exception.status_code == 503:
exception_mapping_worked = True exception_mapping_worked = True
raise ServiceUnavailableError( raise ServiceUnavailableError(
message=f"PredibaseException - {original_exception.message}", message=f"{custom_llm_provider}Exception - {original_exception.message}",
model=model, model=model,
llm_provider=custom_llm_provider, llm_provider=custom_llm_provider,
litellm_debug_info=extra_information, litellm_debug_info=extra_information,
@ -6812,7 +6815,7 @@ def exception_type(
elif original_exception.status_code == 504: # gateway timeout error elif original_exception.status_code == 504: # gateway timeout error
exception_mapping_worked = True exception_mapping_worked = True
raise Timeout( raise Timeout(
message=f"PredibaseException - {original_exception.message}", message=f"{custom_llm_provider}Exception - {original_exception.message}",
model=model, model=model,
llm_provider=custom_llm_provider, llm_provider=custom_llm_provider,
litellm_debug_info=extra_information, litellm_debug_info=extra_information,