diff --git a/litellm/tests/test_get_optional_params_embeddings.py b/litellm/tests/test_get_optional_params_embeddings.py index 41396b531..81b177030 100644 --- a/litellm/tests/test_get_optional_params_embeddings.py +++ b/litellm/tests/test_get_optional_params_embeddings.py @@ -40,3 +40,32 @@ def test_vertex_projects(): # test_vertex_projects() + + +def test_bedrock_embed_v2_regular(): + model, custom_llm_provider, _, _ = get_llm_provider( + model="bedrock/amazon.titan-embed-text-v2:0" + ) + optional_params = get_optional_params_embeddings( + model=model, + dimensions=512, + custom_llm_provider=custom_llm_provider, + ) + print(f"received optional_params: {optional_params}") + assert optional_params == {"dimensions": 512} + + +def test_bedrock_embed_v2_with_drop_params(): + litellm.drop_params = True + model, custom_llm_provider, _, _ = get_llm_provider( + model="bedrock/amazon.titan-embed-text-v2:0" + ) + optional_params = get_optional_params_embeddings( + model=model, + dimensions=512, + user="test-litellm-user-5", + encoding_format="base64", + custom_llm_provider=custom_llm_provider, + ) + print(f"received optional_params: {optional_params}") + assert optional_params == {"dimensions": 512} diff --git a/litellm/utils.py b/litellm/utils.py index e124358b6..5070e6498 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -4744,9 +4744,14 @@ def get_optional_params_embeddings( message=f"Setting user/encoding format is not supported by {custom_llm_provider}. To drop it from the call, set `litellm.drop_params = True`.", ) if custom_llm_provider == "bedrock": - if "amazon.titan-embed-text-v2" in model: - # embed-text-v2 supports the dimension param + # if dimensions is in non_default_params -> pass it for model=bedrock/amazon.titan-embed-text-v2 + if ( + "dimensions" in non_default_params.keys() + and "amazon.titan-embed-text-v2" in model + ): + kwargs["dimensions"] = non_default_params["dimensions"] non_default_params.pop("dimensions", None) + if len(non_default_params.keys()) > 0: if litellm.drop_params is True: # drop the unsupported non-default values keys = list(non_default_params.keys()) @@ -4758,6 +4763,7 @@ def get_optional_params_embeddings( status_code=500, message=f"Setting user/encoding format is not supported by {custom_llm_provider}. To drop it from the call, set `litellm.drop_params = True`.", ) + return {**non_default_params, **kwargs} if ( custom_llm_provider != "openai"