diff --git a/litellm/llms/bedrock.py b/litellm/llms/bedrock.py index 0b61f6bba8..c106555f0f 100644 --- a/litellm/llms/bedrock.py +++ b/litellm/llms/bedrock.py @@ -364,12 +364,16 @@ def completion( response = response.get('body') return response - response = client.invoke_model( - body=data, - modelId=model, - accept=accept, - contentType=contentType - ) + try: + response = client.invoke_model( + body=data, + modelId=model, + accept=accept, + contentType=contentType + ) + except Exception as e: + raise BedrockError(status_code=500, message=str(e)) + response_body = json.loads(response.get('body').read()) ## LOGGING @@ -391,11 +395,20 @@ def completion( outputText = response_body["generations"][0]["text"] else: # amazon titan outputText = response_body.get('results')[0].get('outputText') - try: - if len(outputText) > 0: - model_response["choices"][0]["message"]["content"] = outputText - except: - raise BedrockError(message=json.dumps(outputText), status_code=response.status_code) + + response_metadata = response.get("ResponseMetadata", {}) + + if response_metadata.get("HTTPStatusCode", 500) >= 400: + raise BedrockError( + message=outputText, + status_code=response.get("HTTPStatusCode", 500), + ) + else: + try: + if len(outputText) > 0: + model_response["choices"][0]["message"]["content"] = outputText + except: + raise BedrockError(message=json.dumps(outputText), status_code=response.status_code) ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here. prompt_tokens = len( diff --git a/litellm/tests/test_exceptions.py b/litellm/tests/test_exceptions.py index b7cb8375ac..950443171b 100644 --- a/litellm/tests/test_exceptions.py +++ b/litellm/tests/test_exceptions.py @@ -38,12 +38,20 @@ models = ["command-nightly"] # Test 1: Context Window Errors @pytest.mark.parametrize("model", models) def test_context_window(model): - sample_text = "how does a court case get to the Supreme Court?" * 1000 + sample_text = "Say error 50 times" * 100000 messages = [{"content": sample_text, "role": "user"}] - - with pytest.raises(ContextWindowExceededError): + print(f"model: {model}") + try: completion(model=model, messages=messages) - + pytest.fail(f"An exception occurred") + except ContextWindowExceededError: + pass + except RateLimitError: + pass + except Exception as e: + print(f"{e}") + pytest.fail(f"An error occcurred - {e}") + @pytest.mark.parametrize("model", models) def test_context_window_with_fallbacks(model): ctx_window_fallback_dict = {"command-nightly": "claude-2"} @@ -52,8 +60,10 @@ def test_context_window_with_fallbacks(model): completion(model=model, messages=messages, context_window_fallback_dict=ctx_window_fallback_dict) +# for model in litellm.models_by_provider["bedrock"]: +# test_context_window(model=model) # test_context_window(model="command-nightly") -test_context_window_with_fallbacks(model="command-nightly") +# test_context_window_with_fallbacks(model="command-nightly") # Test 2: InvalidAuth Errors @pytest.mark.parametrize("model", models) def invalid_auth(model): # set the model key to an invalid key, depending on the model diff --git a/litellm/utils.py b/litellm/utils.py index 3f56e2ceb8..6a94afb515 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -2939,7 +2939,14 @@ def exception_type( model=model ) elif custom_llm_provider == "bedrock": - if "Unable to locate credentials" in error_str: + if "too many tokens" in error_str or "expected maxLength:" in error_str or "Input is too long" in error_str or "Too many input tokens" in error_str: + exception_mapping_worked = True + raise ContextWindowExceededError( + message=f"BedrockException: Context Window Error - {error_str}", + model=model, + llm_provider="bedrock" + ) + if "Unable to locate credentials" in error_str or "Malformed input request" in error_str: exception_mapping_worked = True raise InvalidRequestError( message=f"BedrockException - {error_str}", @@ -2953,13 +2960,21 @@ def exception_type( model=model, llm_provider="bedrock" ) - if "throttlingException" in error_str: + if "throttlingException" in error_str or "ThrottlingException" in error_str: exception_mapping_worked = True raise RateLimitError( message=f"BedrockException: Rate Limit Error - {error_str}", model=model, llm_provider="bedrock" ) + if hasattr(original_exception, "status_code"): + if original_exception.status_code == 500: + exception_mapping_worked = True + raise ServiceUnavailableError( + message=f"BedrockException - {original_exception.message}", + llm_provider="bedrock", + model=model + ) elif custom_llm_provider == "sagemaker": if "Unable to locate credentials" in error_str: exception_mapping_worked = True