mirror of https://github.com/BerriAI/litellm.git
fix(bedrock.py): add exception mapping coverage for authentication scenarios
parent 142750adff
commit 1c4dd0671b
3 changed files with 29 additions and 8 deletions
@@ -190,7 +190,7 @@ def init_bedrock_client(
     elif standard_aws_region_name:
         region_name = standard_aws_region_name
     else:
-        raise BedrockError(message="AWS region not set: set AWS_REGION_NAME or AWS_REGION env variable or in .env file")
+        raise BedrockError(message="AWS region not set: set AWS_REGION_NAME or AWS_REGION env variable or in .env file", status_code=401)
 
     # check for custom AWS_BEDROCK_RUNTIME_ENDPOINT and use it if not passed to init_bedrock_client
     env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")
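For readers following along: the only change in this hunk is that the missing-region error now carries status_code=401, so it can be mapped to an authentication failure further up the stack. A minimal, self-contained sketch of the fallback order the error message describes (the helper name and the RuntimeError stand-in are illustrative, not litellm code):

import os

def resolve_aws_region(region_name=None):
    # Fallback order implied by the error message: explicit argument,
    # then AWS_REGION_NAME, then AWS_REGION.
    region = (
        region_name
        or os.environ.get("AWS_REGION_NAME")
        or os.environ.get("AWS_REGION")
    )
    if region is None:
        # litellm raises BedrockError(..., status_code=401) here so the
        # exception mapper can surface an AuthenticationError.
        raise RuntimeError("AWS region not set: set AWS_REGION_NAME or AWS_REGION")
    return region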
@@ -397,18 +397,18 @@ def completion(
         outputText = response_body.get('results')[0].get('outputText')
 
     response_metadata = response.get("ResponseMetadata", {})
     print(f"response_metadata: {response_metadata}")
     if response_metadata.get("HTTPStatusCode", 500) >= 400:
         raise BedrockError(
             message=outputText,
-            status_code=response.get("HTTPStatusCode", 500),
+            status_code=response_metadata.get("HTTPStatusCode", 500),
         )
     else:
         try:
             if len(outputText) > 0:
                 model_response["choices"][0]["message"]["content"] = outputText
         except:
-            raise BedrockError(message=json.dumps(outputText), status_code=response.status_code)
+            raise BedrockError(message=json.dumps(outputText), status_code=response_metadata.get("HTTPStatusCode", 500))
 
     ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
     prompt_tokens = len(
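The substantive fix in this hunk is reading the HTTP status from the boto3 response's ResponseMetadata dict, rather than from the top level of the response (which has no HTTPStatusCode key) or from a response.status_code attribute that a plain dict does not have. A small illustration of the shape involved (values are made up):

# Typical shape of a boto3 invoke_model-style response dict (values illustrative).
response = {
    "ResponseMetadata": {"HTTPStatusCode": 403, "RequestId": "abc-123"},
    "body": b"...",
}

# The status lives under ResponseMetadata, which is why the diff switches to
# response_metadata.get("HTTPStatusCode", 500).
response_metadata = response.get("ResponseMetadata", {})
status = response_metadata.get("HTTPStatusCode", 500)
if status >= 400:
    raise RuntimeError(f"Bedrock call failed with HTTP {status}")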
@@ -73,6 +73,13 @@ def invalid_auth(model): # set the model key to an invalid key, depending on the model
         if model == "gpt-3.5-turbo":
             temporary_key = os.environ["OPENAI_API_KEY"]
             os.environ["OPENAI_API_KEY"] = "bad-key"
+        elif model == "bedrock/anthropic.claude-v2":
+            temporary_aws_access_key = os.environ["AWS_ACCESS_KEY_ID"]
+            os.environ["AWS_ACCESS_KEY_ID"] = "bad-key"
+            temporary_aws_region_name = os.environ["AWS_REGION_NAME"]
+            os.environ["AWS_REGION_NAME"] = "bad-key"
+            temporary_secret_key = os.environ["AWS_SECRET_ACCESS_KEY"]
+            os.environ["AWS_SECRET_ACCESS_KEY"] = "bad-key"
         elif model == "chatgpt-test":
             temporary_key = os.environ["AZURE_API_KEY"]
             os.environ["AZURE_API_KEY"] = "bad-key"
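The new bedrock branch saves and overwrites three credentials-related environment variables by hand. For what it's worth, the same save-and-restore can be expressed with unittest.mock.patch.dict, which puts the originals back automatically; a sketch (the assert is a placeholder, not part of the commit):

import os
from unittest import mock

bad_creds = {
    "AWS_ACCESS_KEY_ID": "bad-key",
    "AWS_SECRET_ACCESS_KEY": "bad-key",
    "AWS_REGION_NAME": "bad-key",
}

# patch.dict overwrites the keys for the duration of the block and restores
# the original values on exit, even if the body raises.
with mock.patch.dict(os.environ, bad_creds):
    assert os.environ["AWS_ACCESS_KEY_ID"] == "bad-key"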
@@ -109,10 +116,10 @@ def invalid_auth(model): # set the model key to an invalid key, depending on the model
         )
         print(f"response: {response}")
     except AuthenticationError as e:
-        print(f"AuthenticationError Caught Exception - {e.llm_provider}")
+        print(f"AuthenticationError Caught Exception - {str(e)}")
     except (
         OpenAIError
-    ): # is at least an openai error -> in case of random model errors - e.g. overloaded server
+    ) as e: # is at least an openai error -> in case of random model errors - e.g. overloaded server
         print(f"OpenAIError Caught Exception - {e}")
     except Exception as e:
         print(type(e))
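The two changes here are small but useful: printing str(e) instead of only e.llm_provider, and binding the caught OpenAIError to a name so it can actually be printed. If one wanted a hard assertion instead of prints, a pytest-style check might look like the sketch below, using litellm.exceptions.AuthenticationError (adjust the import if the test's AuthenticationError comes from elsewhere); this is not part of the commit:

import os
import pytest
import litellm

def check_bedrock_invalid_auth():
    # Force bad credentials, then expect the mapped AuthenticationError
    # rather than a raw boto3/client error.
    os.environ["AWS_ACCESS_KEY_ID"] = "bad-key"
    os.environ["AWS_SECRET_ACCESS_KEY"] = "bad-key"
    with pytest.raises(litellm.exceptions.AuthenticationError):
        litellm.completion(
            model="bedrock/anthropic.claude-v2",
            messages=[{"role": "user", "content": "hello"}],
        )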
@@ -143,8 +150,15 @@ def invalid_auth(model): # set the model key to an invalid key, depending on the model
             os.environ["ALEPH_ALPHA_API_KEY"] = temporary_key
         elif model in litellm.nlp_cloud_models:
             os.environ["NLP_CLOUD_API_KEY"] = temporary_key
+        elif "bedrock" in model:
+            os.environ["AWS_ACCESS_KEY_ID"] = temporary_aws_access_key
+            os.environ["AWS_REGION_NAME"] = temporary_aws_region_name
+            os.environ["AWS_SECRET_ACCESS_KEY"] = temporary_secret_key
     return
+
+for model in litellm.models_by_provider["bedrock"]:
+    invalid_auth(model=model)
 
 # Test 3: Invalid Request Error
 @pytest.mark.parametrize("model", models)
 def test_invalid_request_error(model):
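The module-level loop above exercises invalid_auth for every Bedrock model. As a purely stylistic alternative (not what the commit does), the same coverage could be folded into the parametrize pattern the neighbouring tests already use:

import pytest
import litellm

# Hypothetical parametrized wrapper; invalid_auth is the helper defined above.
@pytest.mark.parametrize("model", litellm.models_by_provider["bedrock"])
def test_invalid_auth_bedrock(model):
    invalid_auth(model=model)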
@@ -2946,14 +2946,14 @@ def exception_type(
                     model=model,
                     llm_provider="bedrock"
                 )
-            if "Unable to locate credentials" in error_str or "Malformed input request" in error_str:
+            if "Malformed input request" in error_str:
                 exception_mapping_worked = True
                 raise InvalidRequestError(
                     message=f"BedrockException - {error_str}",
                     model=model,
                     llm_provider="bedrock"
                 )
-            if "The security token included in the request is invalid" in error_str:
+            if "Unable to locate credentials" in error_str or "The security token included in the request is invalid" in error_str:
                 exception_mapping_worked = True
                 raise AuthenticationError(
                     message=f"BedrockException Invalid Authentication - {error_str}",
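The net effect of this hunk: "Malformed input request" still maps to InvalidRequestError, while "Unable to locate credentials" moves into the authentication bucket alongside the invalid-security-token message. An illustrative summary of the resulting classification (plain strings instead of litellm's exception classes):

def classify_bedrock_error(error_str: str) -> str:
    # Mirrors the substring checks after this change, in the same order.
    if "Malformed input request" in error_str:
        return "InvalidRequestError"
    if (
        "Unable to locate credentials" in error_str
        or "The security token included in the request is invalid" in error_str
    ):
        return "AuthenticationError"
    return "unmapped"

# Example: a missing-credentials message is now treated as an auth failure.
assert classify_bedrock_error("Unable to locate credentials") == "AuthenticationError"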
@@ -2975,6 +2975,13 @@ def exception_type(
                         llm_provider="bedrock",
                         model=model
                     )
+                elif original_exception.status_code == 401:
+                    exception_mapping_worked = True
+                    raise AuthenticationError(
+                        message=f"BedrockException - {original_exception.message}",
+                        llm_provider="bedrock",
+                        model=model
+                    )
         elif custom_llm_provider == "sagemaker":
             if "Unable to locate credentials" in error_str:
                 exception_mapping_worked = True
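This new branch closes the loop with the bedrock.py change above: a BedrockError raised with status_code=401 (for example, the missing-region case) is re-raised as an AuthenticationError by the mapper. A compact, self-contained sketch of that flow (stand-in exception classes and helper; not litellm's actual signatures):

class BedrockError(Exception):
    # Stand-in for the provider-level error raised in bedrock.py.
    def __init__(self, message, status_code):
        super().__init__(message)
        self.message = message
        self.status_code = status_code


class AuthenticationError(Exception):
    # Stand-in for the mapped, user-facing exception.
    pass


def map_bedrock_exception(exc: BedrockError):
    # Mirrors the elif original_exception.status_code == 401 branch above.
    if getattr(exc, "status_code", None) == 401:
        raise AuthenticationError(f"BedrockException - {exc.message}") from exc
    raise exc


try:
    map_bedrock_exception(BedrockError("AWS region not set", status_code=401))
except AuthenticationError as e:
    print(f"mapped to AuthenticationError: {e}")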