diff --git a/litellm/llms/bedrock.py b/litellm/llms/bedrock.py
index 3b9d71f775..85b820c9e8 100644
--- a/litellm/llms/bedrock.py
+++ b/litellm/llms/bedrock.py
@@ -667,6 +667,9 @@ def _embedding_func_single(
     inference_params.pop(
         "user", None
     )  # make sure user is not passed in for bedrock call
+    modelId = (
+        optional_params.pop("model_id", None) or model
+    )  # default to model if not passed
     if provider == "amazon":
         input = input.replace(os.linesep, " ")
         data = {"inputText": input, **inference_params}
@@ -681,7 +684,7 @@ def _embedding_func_single(
     request_str = f"""
    response = client.invoke_model(
        body={body},
-       modelId={model},
+       modelId={modelId},
        accept="*/*",
        contentType="application/json",
    )"""  # type: ignore
@@ -689,14 +692,14 @@
        input=input,
        api_key="",  # boto3 is used for init.
        additional_args={
-           "complete_input_dict": {"model": model, "texts": input},
+           "complete_input_dict": {"model": modelId, "texts": input},
            "request_str": request_str,
        },
    )
    try:
        response = client.invoke_model(
            body=body,
-           modelId=model,
+           modelId=modelId,
            accept="*/*",
            contentType="application/json",
        )
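
Below is a minimal usage sketch of the new model_id override, assuming extra keyword arguments passed to litellm.embedding() are forwarded into optional_params for the Bedrock embedding call; the provisioned-throughput ARN shown is a hypothetical placeholder, not a real resource.

# Minimal sketch (assumption: extra kwargs to litellm.embedding() reach
# optional_params inside _embedding_func_single, as the diff above implies).
# The ARN below is a hypothetical placeholder for a provisioned-throughput model.
import litellm

response = litellm.embedding(
    model="bedrock/amazon.titan-embed-text-v1",  # base model; used as fallback when model_id is absent
    input=["Hello world"],
    model_id="arn:aws:bedrock:us-west-2:000000000000:provisioned-model/example",  # overrides modelId passed to invoke_model
)
print(len(response.data[0]["embedding"]))  # length of the returned embedding vector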