(feat) provisioned throughput - bedrock embedding models

This commit is contained in:
ishaan-jaff 2024-01-13 21:07:38 -08:00
parent 719b051b3d
commit 069d060ec9

View file

@@ -667,6 +667,9 @@ def _embedding_func_single(
         inference_params.pop(
             "user", None
         )  # make sure user is not passed in for bedrock call
+        modelId = (
+            optional_params.pop("model_id", None) or model
+        )  # default to model if not passed
         if provider == "amazon":
             input = input.replace(os.linesep, " ")
             data = {"inputText": input, **inference_params}
@@ -681,7 +684,7 @@ def _embedding_func_single(
         request_str = f"""
         response = client.invoke_model(
             body={body},
-            modelId={model},
+            modelId={modelId},
             accept="*/*",
             contentType="application/json",
         )"""  # type: ignore
@@ -689,14 +692,14 @@ def _embedding_func_single(
             input=input,
             api_key="",  # boto3 is used for init.
             additional_args={
-                "complete_input_dict": {"model": model, "texts": input},
+                "complete_input_dict": {"model": modelId, "texts": input},
                 "request_str": request_str,
             },
         )
         try:
             response = client.invoke_model(
                 body=body,
-                modelId=model,
+                modelId=modelId,
                 accept="*/*",
                 contentType="application/json",
             )