mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
fix(bedrock.py): add exception mapping coverage for authentication scenarios
This commit is contained in:
parent
142750adff
commit
1c4dd0671b
3 changed files with 29 additions and 8 deletions
|
@ -190,7 +190,7 @@ def init_bedrock_client(
|
||||||
elif standard_aws_region_name:
|
elif standard_aws_region_name:
|
||||||
region_name = standard_aws_region_name
|
region_name = standard_aws_region_name
|
||||||
else:
|
else:
|
||||||
raise BedrockError(message="AWS region not set: set AWS_REGION_NAME or AWS_REGION env variable or in .env file")
|
raise BedrockError(message="AWS region not set: set AWS_REGION_NAME or AWS_REGION env variable or in .env file", status_code=401)
|
||||||
|
|
||||||
# check for custom AWS_BEDROCK_RUNTIME_ENDPOINT and use it if not passed to init_bedrock_client
|
# check for custom AWS_BEDROCK_RUNTIME_ENDPOINT and use it if not passed to init_bedrock_client
|
||||||
env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")
|
env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")
|
||||||
|
@ -397,18 +397,18 @@ def completion(
|
||||||
outputText = response_body.get('results')[0].get('outputText')
|
outputText = response_body.get('results')[0].get('outputText')
|
||||||
|
|
||||||
response_metadata = response.get("ResponseMetadata", {})
|
response_metadata = response.get("ResponseMetadata", {})
|
||||||
|
print(f"response_metadata: {response_metadata}")
|
||||||
if response_metadata.get("HTTPStatusCode", 500) >= 400:
|
if response_metadata.get("HTTPStatusCode", 500) >= 400:
|
||||||
raise BedrockError(
|
raise BedrockError(
|
||||||
message=outputText,
|
message=outputText,
|
||||||
status_code=response.get("HTTPStatusCode", 500),
|
status_code=response_metadata.get("HTTPStatusCode", 500),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
if len(outputText) > 0:
|
if len(outputText) > 0:
|
||||||
model_response["choices"][0]["message"]["content"] = outputText
|
model_response["choices"][0]["message"]["content"] = outputText
|
||||||
except:
|
except:
|
||||||
raise BedrockError(message=json.dumps(outputText), status_code=response.status_code)
|
raise BedrockError(message=json.dumps(outputText), status_code=response_metadata.get("HTTPStatusCode", 500))
|
||||||
|
|
||||||
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
|
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
|
||||||
prompt_tokens = len(
|
prompt_tokens = len(
|
||||||
|
|
|
@ -73,6 +73,13 @@ def invalid_auth(model): # set the model key to an invalid key, depending on th
|
||||||
if model == "gpt-3.5-turbo":
|
if model == "gpt-3.5-turbo":
|
||||||
temporary_key = os.environ["OPENAI_API_KEY"]
|
temporary_key = os.environ["OPENAI_API_KEY"]
|
||||||
os.environ["OPENAI_API_KEY"] = "bad-key"
|
os.environ["OPENAI_API_KEY"] = "bad-key"
|
||||||
|
elif model == "bedrock/anthropic.claude-v2":
|
||||||
|
temporary_aws_access_key = os.environ["AWS_ACCESS_KEY_ID"]
|
||||||
|
os.environ["AWS_ACCESS_KEY_ID"] = "bad-key"
|
||||||
|
temporary_aws_region_name = os.environ["AWS_REGION_NAME"]
|
||||||
|
os.environ["AWS_REGION_NAME"] = "bad-key"
|
||||||
|
temporary_secret_key = os.environ["AWS_SECRET_ACCESS_KEY"]
|
||||||
|
os.environ["AWS_SECRET_ACCESS_KEY"] = "bad-key"
|
||||||
elif model == "chatgpt-test":
|
elif model == "chatgpt-test":
|
||||||
temporary_key = os.environ["AZURE_API_KEY"]
|
temporary_key = os.environ["AZURE_API_KEY"]
|
||||||
os.environ["AZURE_API_KEY"] = "bad-key"
|
os.environ["AZURE_API_KEY"] = "bad-key"
|
||||||
|
@ -109,10 +116,10 @@ def invalid_auth(model): # set the model key to an invalid key, depending on th
|
||||||
)
|
)
|
||||||
print(f"response: {response}")
|
print(f"response: {response}")
|
||||||
except AuthenticationError as e:
|
except AuthenticationError as e:
|
||||||
print(f"AuthenticationError Caught Exception - {e.llm_provider}")
|
print(f"AuthenticationError Caught Exception - {str(e)}")
|
||||||
except (
|
except (
|
||||||
OpenAIError
|
OpenAIError
|
||||||
): # is at least an openai error -> in case of random model errors - e.g. overloaded server
|
) as e: # is at least an openai error -> in case of random model errors - e.g. overloaded server
|
||||||
print(f"OpenAIError Caught Exception - {e}")
|
print(f"OpenAIError Caught Exception - {e}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(type(e))
|
print(type(e))
|
||||||
|
@ -143,8 +150,15 @@ def invalid_auth(model): # set the model key to an invalid key, depending on th
|
||||||
os.environ["ALEPH_ALPHA_API_KEY"] = temporary_key
|
os.environ["ALEPH_ALPHA_API_KEY"] = temporary_key
|
||||||
elif model in litellm.nlp_cloud_models:
|
elif model in litellm.nlp_cloud_models:
|
||||||
os.environ["NLP_CLOUD_API_KEY"] = temporary_key
|
os.environ["NLP_CLOUD_API_KEY"] = temporary_key
|
||||||
|
elif "bedrock" in model:
|
||||||
|
os.environ["AWS_ACCESS_KEY_ID"] = temporary_aws_access_key
|
||||||
|
os.environ["AWS_REGION_NAME"] = temporary_aws_region_name
|
||||||
|
os.environ["AWS_SECRET_ACCESS_KEY"] = temporary_secret_key
|
||||||
return
|
return
|
||||||
|
|
||||||
|
for model in litellm.models_by_provider["bedrock"]:
|
||||||
|
invalid_auth(model=model)
|
||||||
|
|
||||||
# Test 3: Invalid Request Error
|
# Test 3: Invalid Request Error
|
||||||
@pytest.mark.parametrize("model", models)
|
@pytest.mark.parametrize("model", models)
|
||||||
def test_invalid_request_error(model):
|
def test_invalid_request_error(model):
|
||||||
|
|
|
@ -2946,14 +2946,14 @@ def exception_type(
|
||||||
model=model,
|
model=model,
|
||||||
llm_provider="bedrock"
|
llm_provider="bedrock"
|
||||||
)
|
)
|
||||||
if "Unable to locate credentials" in error_str or "Malformed input request" in error_str:
|
if "Malformed input request" in error_str:
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
raise InvalidRequestError(
|
raise InvalidRequestError(
|
||||||
message=f"BedrockException - {error_str}",
|
message=f"BedrockException - {error_str}",
|
||||||
model=model,
|
model=model,
|
||||||
llm_provider="bedrock"
|
llm_provider="bedrock"
|
||||||
)
|
)
|
||||||
if "The security token included in the request is invalid" in error_str:
|
if "Unable to locate credentials" in error_str or "The security token included in the request is invalid" in error_str:
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
raise AuthenticationError(
|
raise AuthenticationError(
|
||||||
message=f"BedrockException Invalid Authentication - {error_str}",
|
message=f"BedrockException Invalid Authentication - {error_str}",
|
||||||
|
@ -2975,6 +2975,13 @@ def exception_type(
|
||||||
llm_provider="bedrock",
|
llm_provider="bedrock",
|
||||||
model=model
|
model=model
|
||||||
)
|
)
|
||||||
|
elif original_exception.status_code == 401:
|
||||||
|
exception_mapping_worked = True
|
||||||
|
raise AuthenticationError(
|
||||||
|
message=f"BedrockException - {original_exception.message}",
|
||||||
|
llm_provider="bedrock",
|
||||||
|
model=model
|
||||||
|
)
|
||||||
elif custom_llm_provider == "sagemaker":
|
elif custom_llm_provider == "sagemaker":
|
||||||
if "Unable to locate credentials" in error_str:
|
if "Unable to locate credentials" in error_str:
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue