fix(bedrock.py): fix bedrock exception mapping

This commit is contained in:
Krrish Dholakia 2023-11-03 18:11:50 -07:00
parent 53371d37b7
commit 142750adff
3 changed files with 56 additions and 18 deletions

View file

@@ -364,12 +364,16 @@ def completion(
response = response.get('body') response = response.get('body')
return response return response
try:
response = client.invoke_model( response = client.invoke_model(
body=data, body=data,
modelId=model, modelId=model,
accept=accept, accept=accept,
contentType=contentType contentType=contentType
) )
except Exception as e:
raise BedrockError(status_code=500, message=str(e))
response_body = json.loads(response.get('body').read()) response_body = json.loads(response.get('body').read())
## LOGGING ## LOGGING
@@ -391,6 +395,15 @@ def completion(
outputText = response_body["generations"][0]["text"] outputText = response_body["generations"][0]["text"]
else: # amazon titan else: # amazon titan
outputText = response_body.get('results')[0].get('outputText') outputText = response_body.get('results')[0].get('outputText')
response_metadata = response.get("ResponseMetadata", {})
if response_metadata.get("HTTPStatusCode", 500) >= 400:
raise BedrockError(
message=outputText,
status_code=response.get("HTTPStatusCode", 500),
)
else:
try: try:
if len(outputText) > 0: if len(outputText) > 0:
model_response["choices"][0]["message"]["content"] = outputText model_response["choices"][0]["message"]["content"] = outputText

View file

@@ -38,11 +38,19 @@ models = ["command-nightly"]
# Test 1: Context Window Errors # Test 1: Context Window Errors
@pytest.mark.parametrize("model", models) @pytest.mark.parametrize("model", models)
def test_context_window(model): def test_context_window(model):
sample_text = "how does a court case get to the Supreme Court?" * 1000 sample_text = "Say error 50 times" * 100000
messages = [{"content": sample_text, "role": "user"}] messages = [{"content": sample_text, "role": "user"}]
print(f"model: {model}")
with pytest.raises(ContextWindowExceededError): try:
completion(model=model, messages=messages) completion(model=model, messages=messages)
pytest.fail(f"An exception occurred")
except ContextWindowExceededError:
pass
except RateLimitError:
pass
except Exception as e:
print(f"{e}")
pytest.fail(f"An error occcurred - {e}")
@pytest.mark.parametrize("model", models) @pytest.mark.parametrize("model", models)
def test_context_window_with_fallbacks(model): def test_context_window_with_fallbacks(model):
@@ -52,8 +60,10 @@ def test_context_window_with_fallbacks(model):
completion(model=model, messages=messages, context_window_fallback_dict=ctx_window_fallback_dict) completion(model=model, messages=messages, context_window_fallback_dict=ctx_window_fallback_dict)
# for model in litellm.models_by_provider["bedrock"]:
# test_context_window(model=model)
# test_context_window(model="command-nightly") # test_context_window(model="command-nightly")
test_context_window_with_fallbacks(model="command-nightly") # test_context_window_with_fallbacks(model="command-nightly")
# Test 2: InvalidAuth Errors # Test 2: InvalidAuth Errors
@pytest.mark.parametrize("model", models) @pytest.mark.parametrize("model", models)
def invalid_auth(model): # set the model key to an invalid key, depending on the model def invalid_auth(model): # set the model key to an invalid key, depending on the model

View file

@@ -2939,7 +2939,14 @@ def exception_type(
model=model model=model
) )
elif custom_llm_provider == "bedrock": elif custom_llm_provider == "bedrock":
if "Unable to locate credentials" in error_str: if "too many tokens" in error_str or "expected maxLength:" in error_str or "Input is too long" in error_str or "Too many input tokens" in error_str:
exception_mapping_worked = True
raise ContextWindowExceededError(
message=f"BedrockException: Context Window Error - {error_str}",
model=model,
llm_provider="bedrock"
)
if "Unable to locate credentials" in error_str or "Malformed input request" in error_str:
exception_mapping_worked = True exception_mapping_worked = True
raise InvalidRequestError( raise InvalidRequestError(
message=f"BedrockException - {error_str}", message=f"BedrockException - {error_str}",
@@ -2953,13 +2960,21 @@ def exception_type(
model=model, model=model,
llm_provider="bedrock" llm_provider="bedrock"
) )
if "throttlingException" in error_str: if "throttlingException" in error_str or "ThrottlingException" in error_str:
exception_mapping_worked = True exception_mapping_worked = True
raise RateLimitError( raise RateLimitError(
message=f"BedrockException: Rate Limit Error - {error_str}", message=f"BedrockException: Rate Limit Error - {error_str}",
model=model, model=model,
llm_provider="bedrock" llm_provider="bedrock"
) )
if hasattr(original_exception, "status_code"):
if original_exception.status_code == 500:
exception_mapping_worked = True
raise ServiceUnavailableError(
message=f"BedrockException - {original_exception.message}",
llm_provider="bedrock",
model=model
)
elif custom_llm_provider == "sagemaker": elif custom_llm_provider == "sagemaker":
if "Unable to locate credentials" in error_str: if "Unable to locate credentials" in error_str:
exception_mapping_worked = True exception_mapping_worked = True