fix(converse_transformation.py): fix encoding model

This commit is contained in:
Krrish Dholakia 2025-03-15 14:03:37 -07:00
parent 814d8ba54c
commit 8e7363acf5
4 changed files with 441 additions and 324 deletions

View file

@ -268,23 +268,29 @@ class BedrockConverseLLM(BaseAWSLLM):
## SETUP ##
stream = optional_params.pop("stream", None)
modelId = optional_params.pop("model_id", None)
unencoded_model_id = optional_params.pop("model_id", None)
fake_stream = optional_params.pop("fake_stream", False)
json_mode = optional_params.get("json_mode", False)
if modelId is not None:
modelId = self.encode_model_id(model_id=modelId)
if unencoded_model_id is not None:
modelId = self.encode_model_id(model_id=unencoded_model_id)
else:
modelId = self.encode_model_id(model_id=model)
if stream is True and "ai21" in modelId:
fake_stream = True
### SET REGION NAME ###
aws_region_name = self._get_aws_region_name(
optional_params=optional_params,
model=model,
model_id=unencoded_model_id,
)
## CREDENTIALS ##
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
aws_session_token = optional_params.pop("aws_session_token", None)
aws_region_name = optional_params.pop("aws_region_name", None)
aws_role_name = optional_params.pop("aws_role_name", None)
aws_session_name = optional_params.pop("aws_session_name", None)
aws_profile_name = optional_params.pop("aws_profile_name", None)
@ -293,25 +299,25 @@ class BedrockConverseLLM(BaseAWSLLM):
) # https://bedrock-runtime.{region_name}.amazonaws.com
aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
aws_sts_endpoint = optional_params.pop("aws_sts_endpoint", None)
optional_params.pop("aws_region_name", None)
### SET REGION NAME ###
if aws_region_name is None:
# check env #
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
# if aws_region_name is None:
# # check env #
# litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
if litellm_aws_region_name is not None and isinstance(
litellm_aws_region_name, str
):
aws_region_name = litellm_aws_region_name
# if litellm_aws_region_name is not None and isinstance(
# litellm_aws_region_name, str
# ):
# aws_region_name = litellm_aws_region_name
standard_aws_region_name = get_secret("AWS_REGION", None)
if standard_aws_region_name is not None and isinstance(
standard_aws_region_name, str
):
aws_region_name = standard_aws_region_name
# standard_aws_region_name = get_secret("AWS_REGION", None)
# if standard_aws_region_name is not None and isinstance(
# standard_aws_region_name, str
# ):
# aws_region_name = standard_aws_region_name
if aws_region_name is None:
aws_region_name = "us-west-2"
# if aws_region_name is None:
# aws_region_name = "us-west-2"
litellm_params["aws_region_name"] = (
aws_region_name # [DO NOT DELETE] important for async calls