fix(bedrock.py): add exception mapping coverage for authentication scenarios

This commit is contained in:
Krrish Dholakia 2023-11-03 18:25:34 -07:00
parent 142750adff
commit 1c4dd0671b
3 changed files with 29 additions and 8 deletions

View file

@ -190,7 +190,7 @@ def init_bedrock_client(
elif standard_aws_region_name: elif standard_aws_region_name:
region_name = standard_aws_region_name region_name = standard_aws_region_name
else: else:
raise BedrockError(message="AWS region not set: set AWS_REGION_NAME or AWS_REGION env variable or in .env file") raise BedrockError(message="AWS region not set: set AWS_REGION_NAME or AWS_REGION env variable or in .env file", status_code=401)
# check for custom AWS_BEDROCK_RUNTIME_ENDPOINT and use it if not passed to init_bedrock_client # check for custom AWS_BEDROCK_RUNTIME_ENDPOINT and use it if not passed to init_bedrock_client
env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT") env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")
@ -397,18 +397,18 @@ def completion(
outputText = response_body.get('results')[0].get('outputText') outputText = response_body.get('results')[0].get('outputText')
response_metadata = response.get("ResponseMetadata", {}) response_metadata = response.get("ResponseMetadata", {})
print(f"response_metadata: {response_metadata}")
if response_metadata.get("HTTPStatusCode", 500) >= 400: if response_metadata.get("HTTPStatusCode", 500) >= 400:
raise BedrockError( raise BedrockError(
message=outputText, message=outputText,
status_code=response.get("HTTPStatusCode", 500), status_code=response_metadata.get("HTTPStatusCode", 500),
) )
else: else:
try: try:
if len(outputText) > 0: if len(outputText) > 0:
model_response["choices"][0]["message"]["content"] = outputText model_response["choices"][0]["message"]["content"] = outputText
except: except:
raise BedrockError(message=json.dumps(outputText), status_code=response.status_code) raise BedrockError(message=json.dumps(outputText), status_code=response_metadata.get("HTTPStatusCode", 500))
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here. ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
prompt_tokens = len( prompt_tokens = len(

View file

@ -73,6 +73,13 @@ def invalid_auth(model): # set the model key to an invalid key, depending on th
if model == "gpt-3.5-turbo": if model == "gpt-3.5-turbo":
temporary_key = os.environ["OPENAI_API_KEY"] temporary_key = os.environ["OPENAI_API_KEY"]
os.environ["OPENAI_API_KEY"] = "bad-key" os.environ["OPENAI_API_KEY"] = "bad-key"
elif model == "bedrock/anthropic.claude-v2":
temporary_aws_access_key = os.environ["AWS_ACCESS_KEY_ID"]
os.environ["AWS_ACCESS_KEY_ID"] = "bad-key"
temporary_aws_region_name = os.environ["AWS_REGION_NAME"]
os.environ["AWS_REGION_NAME"] = "bad-key"
temporary_secret_key = os.environ["AWS_SECRET_ACCESS_KEY"]
os.environ["AWS_SECRET_ACCESS_KEY"] = "bad-key"
elif model == "chatgpt-test": elif model == "chatgpt-test":
temporary_key = os.environ["AZURE_API_KEY"] temporary_key = os.environ["AZURE_API_KEY"]
os.environ["AZURE_API_KEY"] = "bad-key" os.environ["AZURE_API_KEY"] = "bad-key"
@ -109,10 +116,10 @@ def invalid_auth(model): # set the model key to an invalid key, depending on th
) )
print(f"response: {response}") print(f"response: {response}")
except AuthenticationError as e: except AuthenticationError as e:
print(f"AuthenticationError Caught Exception - {e.llm_provider}") print(f"AuthenticationError Caught Exception - {str(e)}")
except ( except (
OpenAIError OpenAIError
): # is at least an openai error -> in case of random model errors - e.g. overloaded server ) as e: # is at least an openai error -> in case of random model errors - e.g. overloaded server
print(f"OpenAIError Caught Exception - {e}") print(f"OpenAIError Caught Exception - {e}")
except Exception as e: except Exception as e:
print(type(e)) print(type(e))
@ -143,8 +150,15 @@ def invalid_auth(model): # set the model key to an invalid key, depending on th
os.environ["ALEPH_ALPHA_API_KEY"] = temporary_key os.environ["ALEPH_ALPHA_API_KEY"] = temporary_key
elif model in litellm.nlp_cloud_models: elif model in litellm.nlp_cloud_models:
os.environ["NLP_CLOUD_API_KEY"] = temporary_key os.environ["NLP_CLOUD_API_KEY"] = temporary_key
elif "bedrock" in model:
os.environ["AWS_ACCESS_KEY_ID"] = temporary_aws_access_key
os.environ["AWS_REGION_NAME"] = temporary_aws_region_name
os.environ["AWS_SECRET_ACCESS_KEY"] = temporary_secret_key
return return
for model in litellm.models_by_provider["bedrock"]:
invalid_auth(model=model)
# Test 3: Invalid Request Error # Test 3: Invalid Request Error
@pytest.mark.parametrize("model", models) @pytest.mark.parametrize("model", models)
def test_invalid_request_error(model): def test_invalid_request_error(model):

View file

@ -2946,14 +2946,14 @@ def exception_type(
model=model, model=model,
llm_provider="bedrock" llm_provider="bedrock"
) )
if "Unable to locate credentials" in error_str or "Malformed input request" in error_str: if "Malformed input request" in error_str:
exception_mapping_worked = True exception_mapping_worked = True
raise InvalidRequestError( raise InvalidRequestError(
message=f"BedrockException - {error_str}", message=f"BedrockException - {error_str}",
model=model, model=model,
llm_provider="bedrock" llm_provider="bedrock"
) )
if "The security token included in the request is invalid" in error_str: if "Unable to locate credentials" in error_str or "The security token included in the request is invalid" in error_str:
exception_mapping_worked = True exception_mapping_worked = True
raise AuthenticationError( raise AuthenticationError(
message=f"BedrockException Invalid Authentication - {error_str}", message=f"BedrockException Invalid Authentication - {error_str}",
@ -2975,6 +2975,13 @@ def exception_type(
llm_provider="bedrock", llm_provider="bedrock",
model=model model=model
) )
elif original_exception.status_code == 401:
exception_mapping_worked = True
raise AuthenticationError(
message=f"BedrockException - {original_exception.message}",
llm_provider="bedrock",
model=model
)
elif custom_llm_provider == "sagemaker": elif custom_llm_provider == "sagemaker":
if "Unable to locate credentials" in error_str: if "Unable to locate credentials" in error_str:
exception_mapping_worked = True exception_mapping_worked = True