Merge pull request #3662 from BerriAI/litellm_feat_predibase_exceptions

[Fix] Mask API Keys from Predibase AuthenticationErrors
This commit is contained in:
Ishaan Jaff 2024-05-15 20:45:40 -07:00 committed by GitHub
commit 7aac76b485
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 73 additions and 25 deletions

View file

@@ -596,6 +596,26 @@ def test_litellm_completion_vertex_exception():
print("exception: ", e)
def test_litellm_predibase_exception():
    """
    Regression test: an AuthenticationError raised for a bad Predibase API key
    must not echo the raw key back in its message.
    """
    try:
        import litellm

        litellm.set_verbose = True
        # Deliberately invalid credentials — this call is expected to raise.
        completion(
            model="predibase/llama-3-8b-instruct",
            messages=[{"role": "user", "content": "What is the meaning of life?"}],
            tenant_id="c4768f95",
            api_key="hf-rawapikey",
        )
        pytest.fail("Request should have failed - bad api key")
    except Exception as err:
        # The surfaced exception text must have the key masked out.
        assert "hf-rawapikey" not in str(err)
        print("exception: ", err)
# # test_invalid_request_error(model="command-nightly")
# # Test 3: Rate Limit Errors
# def test_model_call(model):

View file

@@ -8141,33 +8141,39 @@ def exception_type(
# Common Extra information needed for all providers
# We pass num retries, api_base, vertex_deployment etc to the exception here
################################################################################
extra_information = ""
try:
    _api_base = litellm.get_api_base(
        model=model, optional_params=extra_kwargs
    )
    messages = litellm.get_first_chars_messages(kwargs=completion_kwargs)
    _vertex_project = extra_kwargs.get("vertex_project")
    _vertex_location = extra_kwargs.get("vertex_location")
    _metadata = extra_kwargs.get("metadata", {}) or {}
    _model_group = _metadata.get("model_group")
    _deployment = _metadata.get("deployment")
    extra_information = f"\nModel: {model}"
    if _api_base:
        extra_information += f"\nAPI Base: {_api_base}"
    if messages and len(messages) > 0:
        extra_information += f"\nMessages: {messages}"
    if _model_group is not None:
        extra_information += f"\nmodel_group: {_model_group}\n"
    if _deployment is not None:
        extra_information += f"\ndeployment: {_deployment}\n"
    if _vertex_project is not None:
        extra_information += f"\nvertex_project: {_vertex_project}\n"
    if _vertex_location is not None:
        extra_information += f"\nvertex_location: {_vertex_location}\n"

    # on litellm proxy add key name + team to exceptions
    extra_information = _add_key_name_and_team_to_alert(
        request_info=extra_information, metadata=_metadata
    )
except:
    # DO NOT LET this Block raising the original exception
    pass
################################################################################
# End of Common Extra information Needed for all providers
@@ -8532,6 +8538,28 @@ def exception_type(
model=model,
response=original_exception.response,
)
elif custom_llm_provider == "predibase":
if "authorization denied for" in error_str:
exception_mapping_worked = True
# Predibase returns the raw API Key in the response - this block ensures it's not returned in the exception
if (
error_str is not None
and isinstance(error_str, str)
and "bearer" in error_str.lower()
):
# keep only the first few chars after the occurrence of "bearer"; the rest is masked
_bearer_token_start_index = error_str.lower().find("bearer")
error_str = error_str[: _bearer_token_start_index + 14]
error_str += "XXXXXXX" + '"'
raise AuthenticationError(
message=f"PredibaseException: Authentication Error - {error_str}",
llm_provider="predibase",
model=model,
response=original_exception.response,
litellm_debug_info=extra_information,
)
elif custom_llm_provider == "bedrock":
    if (
        "too many tokens" in error_str