fix(bedrock.py): add exception mapping coverage for authentication scenarios

This commit is contained in:
Krrish Dholakia 2023-11-03 18:25:34 -07:00
parent 142750adff
commit 1c4dd0671b
3 changed files with 29 additions and 8 deletions

View file

@ -73,6 +73,13 @@ def invalid_auth(model): # set the model key to an invalid key, depending on th
if model == "gpt-3.5-turbo":
temporary_key = os.environ["OPENAI_API_KEY"]
os.environ["OPENAI_API_KEY"] = "bad-key"
elif model == "bedrock/anthropic.claude-v2":
temporary_aws_access_key = os.environ["AWS_ACCESS_KEY_ID"]
os.environ["AWS_ACCESS_KEY_ID"] = "bad-key"
temporary_aws_region_name = os.environ["AWS_REGION_NAME"]
os.environ["AWS_REGION_NAME"] = "bad-key"
temporary_secret_key = os.environ["AWS_SECRET_ACCESS_KEY"]
os.environ["AWS_SECRET_ACCESS_KEY"] = "bad-key"
elif model == "chatgpt-test":
temporary_key = os.environ["AZURE_API_KEY"]
os.environ["AZURE_API_KEY"] = "bad-key"
@ -109,10 +116,10 @@ def invalid_auth(model): # set the model key to an invalid key, depending on th
)
print(f"response: {response}")
except AuthenticationError as e:
print(f"AuthenticationError Caught Exception - {e.llm_provider}")
print(f"AuthenticationError Caught Exception - {str(e)}")
except (
OpenAIError
): # is at least an openai error -> in case of random model errors - e.g. overloaded server
) as e: # is at least an openai error -> in case of random model errors - e.g. overloaded server
print(f"OpenAIError Caught Exception - {e}")
except Exception as e:
print(type(e))
@ -143,8 +150,15 @@ def invalid_auth(model): # set the model key to an invalid key, depending on th
os.environ["ALEPH_ALPHA_API_KEY"] = temporary_key
elif model in litellm.nlp_cloud_models:
os.environ["NLP_CLOUD_API_KEY"] = temporary_key
elif "bedrock" in model:
os.environ["AWS_ACCESS_KEY_ID"] = temporary_aws_access_key
os.environ["AWS_REGION_NAME"] = temporary_aws_region_name
os.environ["AWS_SECRET_ACCESS_KEY"] = temporary_secret_key
return
for model in litellm.models_by_provider["bedrock"]:
invalid_auth(model=model)
# Test 3: Invalid Request Error
@pytest.mark.parametrize("model", models)
def test_invalid_request_error(model):