forked from phoenix/litellm-mirror
Merge pull request #845 from canada4663/upstream-main
Added support for multiple embeddings via Bedrock
This commit is contained in:
commit
e4f1e2b138
2 changed files with 60 additions and 31 deletions
|
@ -483,15 +483,11 @@ def completion(
|
|||
import traceback
|
||||
raise BedrockError(status_code=500, message=traceback.format_exc())
|
||||
|
||||
|
||||
|
||||
def embedding(
|
||||
model: str,
|
||||
input: list,
|
||||
logging_obj=None,
|
||||
model_response=None,
|
||||
optional_params=None,
|
||||
encoding=None,
|
||||
def _embedding_func_single(
|
||||
model: str,
|
||||
input: str,
|
||||
optional_params=None,
|
||||
encoding=None,
|
||||
):
|
||||
# logic for parsing in - calling - parsing out model embedding calls
|
||||
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
|
||||
|
@ -509,38 +505,61 @@ def embedding(
|
|||
aws_region_name=aws_region_name,
|
||||
),
|
||||
)
|
||||
|
||||
# translate to bedrock
|
||||
# bedrock only accepts (str) for inputText
|
||||
if type(input) == list:
|
||||
if len(input) > 1: # input is a list with more than 1 elem, raise Exception, Bedrock only supports one element
|
||||
raise BedrockError(message="Bedrock cannot embed() more than one string - len(input) must always == 1, input = ['hi from litellm']", status_code=400)
|
||||
input_str = "".join(input)
|
||||
|
||||
response = client.invoke_model(
|
||||
body=json.dumps({
|
||||
"inputText": input_str
|
||||
}),
|
||||
modelId=model,
|
||||
accept="*/*",
|
||||
contentType="application/json"
|
||||
input = input.replace(os.linesep, " ")
|
||||
body = json.dumps({"inputText": input})
|
||||
try:
|
||||
response = client.invoke_model(
|
||||
body=body,
|
||||
modelId=model,
|
||||
accept="application/json",
|
||||
contentType="application/json",
|
||||
)
|
||||
response_body = json.loads(response.get("body").read())
|
||||
return response_body.get("embedding")
|
||||
except Exception as e:
|
||||
raise BedrockError(message=f"Embedding Error with model {model}: {e}", status_code=500)
|
||||
|
||||
def embedding(
|
||||
model: str,
|
||||
input: list,
|
||||
api_key: Optional[str] = None,
|
||||
logging_obj=None,
|
||||
model_response=None,
|
||||
optional_params=None,
|
||||
encoding=None,
|
||||
):
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=input,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": {"model": model,
|
||||
"texts": input}},
|
||||
)
|
||||
|
||||
response_body = json.loads(response.get('body').read())
|
||||
## Embedding Call
|
||||
embeddings = [_embedding_func_single(model, i, optional_params) for i in input]
|
||||
|
||||
embedding_response = response_body["embedding"]
|
||||
|
||||
## Populate OpenAI compliant dictionary
|
||||
embedding_response = []
|
||||
for idx, embedding in enumerate(embeddings):
|
||||
embedding_response.append(
|
||||
{
|
||||
"object": "embedding",
|
||||
"index": idx,
|
||||
"embedding": embedding,
|
||||
}
|
||||
)
|
||||
model_response["object"] = "list"
|
||||
model_response["data"] = embedding_response
|
||||
model_response["model"] = model
|
||||
input_tokens = 0
|
||||
|
||||
input_tokens+=len(encoding.encode(input_str))
|
||||
input_str = "".join(input)
|
||||
|
||||
model_response["usage"] = {
|
||||
"prompt_tokens": input_tokens,
|
||||
"total_tokens": input_tokens,
|
||||
}
|
||||
input_tokens+=len(encoding.encode(input_str))
|
||||
|
||||
usage = Usage(
|
||||
prompt_tokens=input_tokens,
|
||||
|
@ -549,4 +568,13 @@ def embedding(
|
|||
)
|
||||
model_response.usage = usage
|
||||
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=input,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": {"model": model,
|
||||
"texts": input}},
|
||||
original_response=embeddings,
|
||||
)
|
||||
|
||||
return model_response
|
||||
|
|
|
@ -95,7 +95,8 @@ def test_cohere_embedding3():
|
|||
def test_bedrock_embedding():
|
||||
try:
|
||||
response = embedding(
|
||||
model="amazon.titan-embed-text-v1", input=["good morning from litellm, attempting to embed data"]
|
||||
model="amazon.titan-embed-text-v1", input=["good morning from litellm, attempting to embed data",
|
||||
"lets test a second string for good measure"]
|
||||
)
|
||||
print(f"response:", response)
|
||||
except Exception as e:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue