Merge pull request #845 from canada4663/upstream-main

Added support for multiple embeddings via Bedrock
This commit is contained in:
Krish Dholakia 2023-11-21 14:00:06 -08:00 committed by GitHub
commit e4f1e2b138
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 60 additions and 31 deletions

View file

@@ -483,15 +483,11 @@ def completion(
import traceback
raise BedrockError(status_code=500, message=traceback.format_exc())
def embedding(
model: str,
input: list,
logging_obj=None,
model_response=None,
optional_params=None,
encoding=None,
def _embedding_func_single(
model: str,
input: str,
optional_params=None,
encoding=None,
):
# logic for parsing in - calling - parsing out model embedding calls
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
@@ -509,38 +505,61 @@ def embedding(
aws_region_name=aws_region_name,
),
)
# translate to bedrock
# bedrock only accepts (str) for inputText
if type(input) == list:
if len(input) > 1: # input is a list with more than 1 elem, raise Exception, Bedrock only supports one element
raise BedrockError(message="Bedrock cannot embed() more than one string - len(input) must always == 1, input = ['hi from litellm']", status_code=400)
input_str = "".join(input)
response = client.invoke_model(
body=json.dumps({
"inputText": input_str
}),
modelId=model,
accept="*/*",
contentType="application/json"
input = input.replace(os.linesep, " ")
body = json.dumps({"inputText": input})
try:
response = client.invoke_model(
body=body,
modelId=model,
accept="application/json",
contentType="application/json",
)
response_body = json.loads(response.get("body").read())
return response_body.get("embedding")
except Exception as e:
raise BedrockError(message=f"Embedding Error with model {model}: {e}", status_code=500)
def embedding(
model: str,
input: list,
api_key: Optional[str] = None,
logging_obj=None,
model_response=None,
optional_params=None,
encoding=None,
):
## LOGGING
logging_obj.pre_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": {"model": model,
"texts": input}},
)
response_body = json.loads(response.get('body').read())
## Embedding Call
embeddings = [_embedding_func_single(model, i, optional_params) for i in input]
embedding_response = response_body["embedding"]
## Populate OpenAI compliant dictionary
embedding_response = []
for idx, embedding in enumerate(embeddings):
embedding_response.append(
{
"object": "embedding",
"index": idx,
"embedding": embedding,
}
)
model_response["object"] = "list"
model_response["data"] = embedding_response
model_response["model"] = model
input_tokens = 0
input_tokens+=len(encoding.encode(input_str))
input_str = "".join(input)
model_response["usage"] = {
"prompt_tokens": input_tokens,
"total_tokens": input_tokens,
}
input_tokens+=len(encoding.encode(input_str))
usage = Usage(
prompt_tokens=input_tokens,
@@ -549,4 +568,13 @@ def embedding(
)
model_response.usage = usage
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": {"model": model,
"texts": input}},
original_response=embeddings,
)
return model_response

View file

@@ -95,7 +95,8 @@ def test_cohere_embedding3():
def test_bedrock_embedding():
try:
response = embedding(
model="amazon.titan-embed-text-v1", input=["good morning from litellm, attempting to embed data"]
model="amazon.titan-embed-text-v1", input=["good morning from litellm, attempting to embed data",
"lets test a second string for good measure"]
)
print(f"response:", response)
except Exception as e: