mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
fix(bedrock.py): fix bedrock exception mapping
This commit is contained in:
parent
53371d37b7
commit
142750adff
3 changed files with 56 additions and 18 deletions
|
@ -364,12 +364,16 @@ def completion(
|
||||||
response = response.get('body')
|
response = response.get('body')
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
try:
|
||||||
response = client.invoke_model(
|
response = client.invoke_model(
|
||||||
body=data,
|
body=data,
|
||||||
modelId=model,
|
modelId=model,
|
||||||
accept=accept,
|
accept=accept,
|
||||||
contentType=contentType
|
contentType=contentType
|
||||||
)
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise BedrockError(status_code=500, message=str(e))
|
||||||
|
|
||||||
response_body = json.loads(response.get('body').read())
|
response_body = json.loads(response.get('body').read())
|
||||||
|
|
||||||
## LOGGING
|
## LOGGING
|
||||||
|
@ -391,6 +395,15 @@ def completion(
|
||||||
outputText = response_body["generations"][0]["text"]
|
outputText = response_body["generations"][0]["text"]
|
||||||
else: # amazon titan
|
else: # amazon titan
|
||||||
outputText = response_body.get('results')[0].get('outputText')
|
outputText = response_body.get('results')[0].get('outputText')
|
||||||
|
|
||||||
|
response_metadata = response.get("ResponseMetadata", {})
|
||||||
|
|
||||||
|
if response_metadata.get("HTTPStatusCode", 500) >= 400:
|
||||||
|
raise BedrockError(
|
||||||
|
message=outputText,
|
||||||
|
status_code=response.get("HTTPStatusCode", 500),
|
||||||
|
)
|
||||||
|
else:
|
||||||
try:
|
try:
|
||||||
if len(outputText) > 0:
|
if len(outputText) > 0:
|
||||||
model_response["choices"][0]["message"]["content"] = outputText
|
model_response["choices"][0]["message"]["content"] = outputText
|
||||||
|
|
|
@ -38,11 +38,19 @@ models = ["command-nightly"]
|
||||||
# Test 1: Context Window Errors
|
# Test 1: Context Window Errors
|
||||||
@pytest.mark.parametrize("model", models)
|
@pytest.mark.parametrize("model", models)
|
||||||
def test_context_window(model):
|
def test_context_window(model):
|
||||||
sample_text = "how does a court case get to the Supreme Court?" * 1000
|
sample_text = "Say error 50 times" * 100000
|
||||||
messages = [{"content": sample_text, "role": "user"}]
|
messages = [{"content": sample_text, "role": "user"}]
|
||||||
|
print(f"model: {model}")
|
||||||
with pytest.raises(ContextWindowExceededError):
|
try:
|
||||||
completion(model=model, messages=messages)
|
completion(model=model, messages=messages)
|
||||||
|
pytest.fail(f"An exception occurred")
|
||||||
|
except ContextWindowExceededError:
|
||||||
|
pass
|
||||||
|
except RateLimitError:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
print(f"{e}")
|
||||||
|
pytest.fail(f"An error occcurred - {e}")
|
||||||
|
|
||||||
@pytest.mark.parametrize("model", models)
|
@pytest.mark.parametrize("model", models)
|
||||||
def test_context_window_with_fallbacks(model):
|
def test_context_window_with_fallbacks(model):
|
||||||
|
@ -52,8 +60,10 @@ def test_context_window_with_fallbacks(model):
|
||||||
|
|
||||||
completion(model=model, messages=messages, context_window_fallback_dict=ctx_window_fallback_dict)
|
completion(model=model, messages=messages, context_window_fallback_dict=ctx_window_fallback_dict)
|
||||||
|
|
||||||
|
# for model in litellm.models_by_provider["bedrock"]:
|
||||||
|
# test_context_window(model=model)
|
||||||
# test_context_window(model="command-nightly")
|
# test_context_window(model="command-nightly")
|
||||||
test_context_window_with_fallbacks(model="command-nightly")
|
# test_context_window_with_fallbacks(model="command-nightly")
|
||||||
# Test 2: InvalidAuth Errors
|
# Test 2: InvalidAuth Errors
|
||||||
@pytest.mark.parametrize("model", models)
|
@pytest.mark.parametrize("model", models)
|
||||||
def invalid_auth(model): # set the model key to an invalid key, depending on the model
|
def invalid_auth(model): # set the model key to an invalid key, depending on the model
|
||||||
|
|
|
@ -2939,7 +2939,14 @@ def exception_type(
|
||||||
model=model
|
model=model
|
||||||
)
|
)
|
||||||
elif custom_llm_provider == "bedrock":
|
elif custom_llm_provider == "bedrock":
|
||||||
if "Unable to locate credentials" in error_str:
|
if "too many tokens" in error_str or "expected maxLength:" in error_str or "Input is too long" in error_str or "Too many input tokens" in error_str:
|
||||||
|
exception_mapping_worked = True
|
||||||
|
raise ContextWindowExceededError(
|
||||||
|
message=f"BedrockException: Context Window Error - {error_str}",
|
||||||
|
model=model,
|
||||||
|
llm_provider="bedrock"
|
||||||
|
)
|
||||||
|
if "Unable to locate credentials" in error_str or "Malformed input request" in error_str:
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
raise InvalidRequestError(
|
raise InvalidRequestError(
|
||||||
message=f"BedrockException - {error_str}",
|
message=f"BedrockException - {error_str}",
|
||||||
|
@ -2953,13 +2960,21 @@ def exception_type(
|
||||||
model=model,
|
model=model,
|
||||||
llm_provider="bedrock"
|
llm_provider="bedrock"
|
||||||
)
|
)
|
||||||
if "throttlingException" in error_str:
|
if "throttlingException" in error_str or "ThrottlingException" in error_str:
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
raise RateLimitError(
|
raise RateLimitError(
|
||||||
message=f"BedrockException: Rate Limit Error - {error_str}",
|
message=f"BedrockException: Rate Limit Error - {error_str}",
|
||||||
model=model,
|
model=model,
|
||||||
llm_provider="bedrock"
|
llm_provider="bedrock"
|
||||||
)
|
)
|
||||||
|
if hasattr(original_exception, "status_code"):
|
||||||
|
if original_exception.status_code == 500:
|
||||||
|
exception_mapping_worked = True
|
||||||
|
raise ServiceUnavailableError(
|
||||||
|
message=f"BedrockException - {original_exception.message}",
|
||||||
|
llm_provider="bedrock",
|
||||||
|
model=model
|
||||||
|
)
|
||||||
elif custom_llm_provider == "sagemaker":
|
elif custom_llm_provider == "sagemaker":
|
||||||
if "Unable to locate credentials" in error_str:
|
if "Unable to locate credentials" in error_str:
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue