refactor(bedrock.py): better exception mapping for bedrock + huggingface

Krrish Dholakia 2023-11-04 16:12:12 -07:00
parent ab54262d37
commit 7c46e85ed6
2 changed files with 158 additions and 141 deletions

litellm/llms/bedrock.py

@@ -278,151 +278,163 @@ def completion(
     litellm_params=None,
     logger_fn=None,
 ):
+    exception_mapping_worked = False
+    try:
         # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
         aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
         aws_access_key_id = optional_params.pop("aws_access_key_id", None)
         aws_region_name = optional_params.pop("aws_region_name", None)

         # use passed in BedrockRuntime.Client if provided, otherwise create a new one
         client = optional_params.pop(
             "aws_bedrock_client",
             # only pass variables that are not None
             init_bedrock_client(
                 aws_access_key_id=aws_access_key_id,
                 aws_secret_access_key=aws_secret_access_key,
                 aws_region_name=aws_region_name,
             ),
         )

         model = model
         provider = model.split(".")[0]
         prompt = convert_messages_to_prompt(model, messages, provider, custom_prompt_dict)
         inference_params = copy.deepcopy(optional_params)
         stream = inference_params.pop("stream", False)
         if provider == "anthropic":
             ## LOAD CONFIG
             config = litellm.AmazonAnthropicConfig.get_config()
             for k, v in config.items():
                 if k not in inference_params: # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
                     inference_params[k] = v
             data = json.dumps({
                 "prompt": prompt,
                 **inference_params
             })
         elif provider == "ai21":
             ## LOAD CONFIG
             config = litellm.AmazonAI21Config.get_config()
             for k, v in config.items():
                 if k not in inference_params: # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
                     inference_params[k] = v
             data = json.dumps({
                 "prompt": prompt,
                 **inference_params
             })
         elif provider == "cohere":
             ## LOAD CONFIG
             config = litellm.AmazonCohereConfig.get_config()
             for k, v in config.items():
                 if k not in inference_params: # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
                     inference_params[k] = v
             if optional_params.get("stream", False) == True:
                 inference_params["stream"] = True # cohere requires stream = True in inference params
             data = json.dumps({
                 "prompt": prompt,
                 **inference_params
             })
         elif provider == "amazon": # amazon titan
             ## LOAD CONFIG
             config = litellm.AmazonTitanConfig.get_config()
             for k, v in config.items():
                 if k not in inference_params: # completion(top_k=3) > amazon_config(top_k=3) <- allows for dynamic variables to be passed in
                     inference_params[k] = v
             data = json.dumps({
                 "inputText": prompt,
                 "textGenerationConfig": inference_params,
             })

         ## LOGGING
         logging_obj.pre_call(
             input=prompt,
             api_key="",
             additional_args={"complete_input_dict": data},
         )

         ## COMPLETION CALL
         accept = 'application/json'
         contentType = 'application/json'
         if stream == True:
             response = client.invoke_model_with_response_stream(
                 body=data,
                 modelId=model,
                 accept=accept,
                 contentType=contentType
             )
             response = response.get('body')
             return response

         try:
             response = client.invoke_model(
                 body=data,
                 modelId=model,
                 accept=accept,
                 contentType=contentType
             )
         except Exception as e:
             raise BedrockError(status_code=500, message=str(e))

         response_body = json.loads(response.get('body').read())

         ## LOGGING
         logging_obj.post_call(
             input=prompt,
             api_key="",
             original_response=response_body,
             additional_args={"complete_input_dict": data},
         )
         print_verbose(f"raw model_response: {response}")
         ## RESPONSE OBJECT
         outputText = "default"
         if provider == "ai21":
             outputText = response_body.get('completions')[0].get('data').get('text')
         elif provider == "anthropic":
             outputText = response_body['completion']
             model_response["finish_reason"] = response_body["stop_reason"]
         elif provider == "cohere":
             outputText = response_body["generations"][0]["text"]
         else: # amazon titan
             outputText = response_body.get('results')[0].get('outputText')

         response_metadata = response.get("ResponseMetadata", {})
         if response_metadata.get("HTTPStatusCode", 500) >= 400:
             raise BedrockError(
                 message=outputText,
                 status_code=response_metadata.get("HTTPStatusCode", 500),
             )
         else:
             try:
                 if len(outputText) > 0:
                     model_response["choices"][0]["message"]["content"] = outputText
             except:
                 raise BedrockError(message=json.dumps(outputText), status_code=response_metadata.get("HTTPStatusCode", 500))

         ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
         prompt_tokens = len(
             encoding.encode(prompt)
         )
         completion_tokens = len(
             encoding.encode(model_response["choices"][0]["message"].get("content", ""))
         )

         model_response["created"] = time.time()
         model_response["model"] = model
         model_response.usage.completion_tokens = completion_tokens
         model_response.usage.prompt_tokens = prompt_tokens
         model_response.usage.total_tokens = prompt_tokens + completion_tokens
         return model_response
+    except BedrockError as e:
+        exception_mapping_worked = True
+        raise e
+    except Exception as e:
+        if exception_mapping_worked:
+            raise e
+        else:
+            import traceback
+            raise BedrockError(status_code=500, message=traceback.format_exc())

 def embedding(

litellm/llms/huggingface_restapi.py

@@ -141,6 +141,7 @@ def completion(
     litellm_params=None,
     logger_fn=None,
 ):
+    exception_mapping_worked = False
     try:
         headers = validate_environment(api_key, headers)
         task = get_hf_task_for_model(model)
@@ -365,10 +366,14 @@ def completion(
         model_response._hidden_params["original_response"] = completion_response
         return model_response
     except HuggingfaceError as e:
+        exception_mapping_worked = True
         raise e
     except Exception as e:
-        import traceback
-        raise HuggingfaceError(status_code=500, message=traceback.format_exc())
+        if exception_mapping_worked:
+            raise e
+        else:
+            import traceback
+            raise HuggingfaceError(status_code=500, message=traceback.format_exc())

 def embedding(
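
The same guard is applied in both files: exceptions that the mapping logic has already converted into the provider-specific error type are re-raised untouched, while anything unexpected is wrapped into that error type with a 500 status code and the full traceback. Below is a minimal, self-contained sketch of the pattern; ProviderError and call_with_exception_mapping are hypothetical stand-ins for illustration, not the real BedrockError / HuggingfaceError classes or any litellm API.

import traceback

class ProviderError(Exception):
    # stand-in for BedrockError / HuggingfaceError in this sketch
    def __init__(self, status_code, message):
        self.status_code = status_code
        self.message = message
        super().__init__(message)

def call_with_exception_mapping(fn):
    exception_mapping_worked = False
    try:
        return fn()
    except ProviderError as e:
        # an already-mapped error keeps its original status code and message
        exception_mapping_worked = True
        raise e
    except Exception:
        if exception_mapping_worked:
            raise
        # anything unmapped surfaces as a 500 carrying the full traceback
        raise ProviderError(status_code=500, message=traceback.format_exc())

For example, call_with_exception_mapping(lambda: 1 / 0) raises ProviderError with status_code=500 and the ZeroDivisionError traceback in its message, whereas a ProviderError(424, "model returned an error") raised inside fn propagates with its original status code.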