Merge branch 'main' into main

This commit is contained in:
Lucca Zenóbio 2024-05-02 09:46:34 -03:00 committed by GitHub
commit 78303b79ee
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
124 changed files with 6716 additions and 1078 deletions

View file

@ -29,6 +29,24 @@ class BedrockError(Exception):
) # Call the base class constructor with the parameters it needs
class AmazonBedrockGlobalConfig:
    """Provider-level Bedrock configuration that is not tied to one model.

    Currently this only knows how to translate the generic auth parameter
    name ``region_name`` into the Bedrock-specific ``aws_region_name``.
    """

    def __init__(self):
        # No instance state; kept so existing callers can keep instantiating.
        pass

    def get_mapped_special_auth_params(self) -> dict:
        """
        Mapping of common auth params across bedrock/vertex/azure/watsonx

        Returns:
            A dict whose keys are the generic parameter names and whose
            values are the names Bedrock expects.
        """
        return {"region_name": "aws_region_name"}

    def map_special_auth_params(
        self, non_default_params: dict, optional_params: dict
    ) -> dict:
        """Copy recognised auth params into ``optional_params`` under their Bedrock names.

        Args:
            non_default_params: Caller-supplied parameters keyed by generic name.
                Keys that are not in the auth mapping are ignored.
            optional_params: Dict to receive the translated entries. It is
                mutated in place; an existing entry with the same translated
                key is overwritten.

        Returns:
            The same ``optional_params`` object that was passed in.
        """
        mapped_params = self.get_mapped_special_auth_params()
        for param, value in non_default_params.items():
            if param in mapped_params:
                optional_params[mapped_params[param]] = value
        return optional_params
class AmazonTitanConfig:
"""
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-text-express-v1
@ -666,6 +684,10 @@ def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
prompt = prompt_factory(
model=model, messages=messages, custom_llm_provider="bedrock"
)
elif provider == "meta":
prompt = prompt_factory(
model=model, messages=messages, custom_llm_provider="bedrock"
)
else:
prompt = ""
for message in messages:
@ -945,7 +967,7 @@ def completion(
original_response=json.dumps(response_body),
additional_args={"complete_input_dict": data},
)
print_verbose(f"raw model_response: {response}")
print_verbose(f"raw model_response: {response_body}")
## RESPONSE OBJECT
outputText = "default"
if provider == "ai21":
@ -1058,6 +1080,7 @@ def completion(
outputText = response_body.get("results")[0].get("outputText")
response_metadata = response.get("ResponseMetadata", {})
if response_metadata.get("HTTPStatusCode", 500) >= 400:
raise BedrockError(
message=outputText,
@ -1093,11 +1116,13 @@ def completion(
prompt_tokens = response_metadata.get(
"x-amzn-bedrock-input-token-count", len(encoding.encode(prompt))
)
_text_response = model_response["choices"][0]["message"].get("content", "")
completion_tokens = response_metadata.get(
"x-amzn-bedrock-output-token-count",
len(
encoding.encode(
model_response["choices"][0]["message"].get("content", "")
_text_response,
disallowed_special=(),
)
),
)