fix(bedrock.py): add exception mapping coverage for authentication scenarios

This commit is contained in:
Krrish Dholakia 2023-11-03 18:25:34 -07:00
parent 142750adff
commit 1c4dd0671b
3 changed files with 29 additions and 8 deletions

View file

@ -190,7 +190,7 @@ def init_bedrock_client(
elif standard_aws_region_name:
region_name = standard_aws_region_name
else:
raise BedrockError(message="AWS region not set: set AWS_REGION_NAME or AWS_REGION env variable or in .env file")
raise BedrockError(message="AWS region not set: set AWS_REGION_NAME or AWS_REGION env variable or in .env file", status_code=401)
# check for custom AWS_BEDROCK_RUNTIME_ENDPOINT and use it if not passed to init_bedrock_client
env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")
@ -397,18 +397,18 @@ def completion(
outputText = response_body.get('results')[0].get('outputText')
response_metadata = response.get("ResponseMetadata", {})
print(f"response_metadata: {response_metadata}")
if response_metadata.get("HTTPStatusCode", 500) >= 400:
raise BedrockError(
message=outputText,
status_code=response.get("HTTPStatusCode", 500),
status_code=response_metadata.get("HTTPStatusCode", 500),
)
else:
try:
if len(outputText) > 0:
model_response["choices"][0]["message"]["content"] = outputText
except:
raise BedrockError(message=json.dumps(outputText), status_code=response.status_code)
raise BedrockError(message=json.dumps(outputText), status_code=response_metadata.get("HTTPStatusCode", 500))
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
prompt_tokens = len(

View file

@ -73,6 +73,13 @@ def invalid_auth(model): # set the model key to an invalid key, depending on th
if model == "gpt-3.5-turbo":
temporary_key = os.environ["OPENAI_API_KEY"]
os.environ["OPENAI_API_KEY"] = "bad-key"
elif model == "bedrock/anthropic.claude-v2":
temporary_aws_access_key = os.environ["AWS_ACCESS_KEY_ID"]
os.environ["AWS_ACCESS_KEY_ID"] = "bad-key"
temporary_aws_region_name = os.environ["AWS_REGION_NAME"]
os.environ["AWS_REGION_NAME"] = "bad-key"
temporary_secret_key = os.environ["AWS_SECRET_ACCESS_KEY"]
os.environ["AWS_SECRET_ACCESS_KEY"] = "bad-key"
elif model == "chatgpt-test":
temporary_key = os.environ["AZURE_API_KEY"]
os.environ["AZURE_API_KEY"] = "bad-key"
@ -109,10 +116,10 @@ def invalid_auth(model): # set the model key to an invalid key, depending on th
)
print(f"response: {response}")
except AuthenticationError as e:
print(f"AuthenticationError Caught Exception - {e.llm_provider}")
print(f"AuthenticationError Caught Exception - {str(e)}")
except (
OpenAIError
): # is at least an openai error -> in case of random model errors - e.g. overloaded server
) as e: # is at least an openai error -> in case of random model errors - e.g. overloaded server
print(f"OpenAIError Caught Exception - {e}")
except Exception as e:
print(type(e))
@ -143,8 +150,15 @@ def invalid_auth(model): # set the model key to an invalid key, depending on th
os.environ["ALEPH_ALPHA_API_KEY"] = temporary_key
elif model in litellm.nlp_cloud_models:
os.environ["NLP_CLOUD_API_KEY"] = temporary_key
elif "bedrock" in model:
os.environ["AWS_ACCESS_KEY_ID"] = temporary_aws_access_key
os.environ["AWS_REGION_NAME"] = temporary_aws_region_name
os.environ["AWS_SECRET_ACCESS_KEY"] = temporary_secret_key
return
for model in litellm.models_by_provider["bedrock"]:
invalid_auth(model=model)
# Test 3: Invalid Request Error
@pytest.mark.parametrize("model", models)
def test_invalid_request_error(model):

View file

@ -2946,14 +2946,14 @@ def exception_type(
model=model,
llm_provider="bedrock"
)
if "Unable to locate credentials" in error_str or "Malformed input request" in error_str:
if "Malformed input request" in error_str:
exception_mapping_worked = True
raise InvalidRequestError(
message=f"BedrockException - {error_str}",
model=model,
llm_provider="bedrock"
)
if "The security token included in the request is invalid" in error_str:
if "Unable to locate credentials" in error_str or "The security token included in the request is invalid" in error_str:
exception_mapping_worked = True
raise AuthenticationError(
message=f"BedrockException Invalid Authentication - {error_str}",
@ -2975,6 +2975,13 @@ def exception_type(
llm_provider="bedrock",
model=model
)
elif original_exception.status_code == 401:
exception_mapping_worked = True
raise AuthenticationError(
message=f"BedrockException - {original_exception.message}",
llm_provider="bedrock",
model=model
)
elif custom_llm_provider == "sagemaker":
if "Unable to locate credentials" in error_str:
exception_mapping_worked = True