diff --git a/litellm/llms/bedrock.py b/litellm/llms/bedrock.py index 16a0abbed7..b67061c76b 100644 --- a/litellm/llms/bedrock.py +++ b/litellm/llms/bedrock.py @@ -702,7 +702,7 @@ def _embedding_func_single( encoding=None, logging_obj=None, ): - if type(input) != str: + if isinstance(input, str) is False: raise BedrockError( message="Bedrock Embedding API input must be type str | List[str]", status_code=400, @@ -800,7 +800,8 @@ def embedding( aws_role_name=aws_role_name, aws_session_name=aws_session_name, ) - if type(input) == str: + if isinstance(input, str): + ## Embedding Call embeddings = [ _embedding_func_single( model, @@ -810,8 +811,8 @@ def embedding( logging_obj=logging_obj, ) ] - elif type(input) == list: - ## Embedding Call + elif isinstance(input, list): + ## Embedding Call - assuming this is a List[str] embeddings = [ _embedding_func_single( model,