forked from phoenix/litellm-mirror

Merge pull request #845 from canada4663/upstream-main

Added support for multiple embeddings via Bedrock

commit e4f1e2b138

2 changed files with 60 additions and 31 deletions
@@ -483,13 +483,9 @@ def completion(
         import traceback

         raise BedrockError(status_code=500, message=traceback.format_exc())


-def embedding(
+def _embedding_func_single(
     model: str,
-    input: list,
-    logging_obj=None,
-    model_response=None,
+    input: str,
     optional_params=None,
     encoding=None,
 ):
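The old single-input embedding() becomes a private per-string helper: logging_obj and model_response drop out of its signature and move to the new list-aware embedding() in the next hunk. As a rough standalone sketch of what one helper call boils down to (embed_one is a hypothetical name; this assumes a boto3 bedrock-runtime client, whereas the real helper builds its client from the file's own AWS setup):

import json
import boto3

def embed_one(model: str, text: str) -> list:
    # One string in, one vector out: Titan embedding models accept a single "inputText".
    client = boto3.client("bedrock-runtime")
    response = client.invoke_model(
        body=json.dumps({"inputText": text}),
        modelId=model,
        accept="application/json",
        contentType="application/json",
    )
    return json.loads(response["body"].read()).get("embedding")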
@@ -510,37 +506,60 @@ def embedding(
         ),
     )

-    # translate to bedrock
-    # bedrock only accepts (str) for inputText
-    if type(input) == list:
-        if len(input) > 1: # input is a list with more than 1 elem, raise Exception, Bedrock only supports one element
-            raise BedrockError(message="Bedrock cannot embed() more than one string - len(input) must always == 1, input = ['hi from litellm']", status_code=400)
-        input_str = "".join(input)
-    response = client.invoke_model(
-        body=json.dumps({
-            "inputText": input_str
-        }),
-        modelId=model,
-        accept="*/*",
-        contentType="application/json"
-    )
-    response_body = json.loads(response.get('body').read())
-    embedding_response = response_body["embedding"]
+    input = input.replace(os.linesep, " ")
+    body = json.dumps({"inputText": input})
+    try:
+        response = client.invoke_model(
+            body=body,
+            modelId=model,
+            accept="application/json",
+            contentType="application/json",
+        )
+        response_body = json.loads(response.get("body").read())
+        return response_body.get("embedding")
+    except Exception as e:
+        raise BedrockError(message=f"Embedding Error with model {model}: {e}", status_code=500)
+
+
+def embedding(
+    model: str,
+    input: list,
+    api_key: Optional[str] = None,
+    logging_obj=None,
+    model_response=None,
+    optional_params=None,
+    encoding=None,
+):
+    ## LOGGING
+    logging_obj.pre_call(
+        input=input,
+        api_key=api_key,
+        additional_args={"complete_input_dict": {"model": model,
+                                                 "texts": input}},
+    )
+
+    ## Embedding Call
+    embeddings = [_embedding_func_single(model, i, optional_params) for i in input]
+
+    ## Populate OpenAI compliant dictionary
+    embedding_response = []
+    for idx, embedding in enumerate(embeddings):
+        embedding_response.append(
+            {
+                "object": "embedding",
+                "index": idx,
+                "embedding": embedding,
+            }
+        )
     model_response["object"] = "list"
     model_response["data"] = embedding_response
     model_response["model"] = model
     input_tokens = 0
-    input_tokens+=len(encoding.encode(input_str))
-    model_response["usage"] = {
-        "prompt_tokens": input_tokens,
-        "total_tokens": input_tokens,
-    }
+    input_str = "".join(input)
+    input_tokens+=len(encoding.encode(input_str))
     usage = Usage(
         prompt_tokens=input_tokens,
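The heart of the change: embedding() now fans out over the input list, invoking the single-string helper once per element, then wraps the vectors in an OpenAI-compliant list. A condensed sketch of that assembly, reusing the hypothetical embed_one from above:

def embed_many(model: str, inputs: list) -> dict:
    # One Bedrock invoke_model round-trip per input string.
    embeddings = [embed_one(model, s) for s in inputs]
    # OpenAI-style shape: {"object": "list", "data": [{"object": "embedding", ...}, ...]}
    return {
        "object": "list",
        "model": model,
        "data": [
            {"object": "embedding", "index": idx, "embedding": vec}
            for idx, vec in enumerate(embeddings)
        ],
    }

Note the trade-off in this design: each string costs a separate network call, so latency grows linearly with len(inputs).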
@@ -549,4 +568,13 @@ def embedding(
     )
     model_response.usage = usage
+
+    ## LOGGING
+    logging_obj.post_call(
+        input=input,
+        api_key=api_key,
+        additional_args={"complete_input_dict": {"model": model,
+                                                 "texts": input}},
+        original_response=embeddings,
+    )
+
     return model_response
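Token accounting stays single-pass: the new code joins every input string and encodes the concatenation once, and that one count feeds both prompt_tokens and total_tokens in Usage. A small sketch, assuming encoding is a tiktoken-style object with an encode() method (which is what the encoding= parameter suggests):

def count_prompt_tokens(encoding, inputs: list) -> int:
    # Mirrors the diff: input_str = "".join(input); input_tokens += len(encoding.encode(input_str))
    return len(encoding.encode("".join(inputs)))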
@@ -95,7 +95,8 @@ def test_cohere_embedding3():
 def test_bedrock_embedding():
     try:
         response = embedding(
-            model="amazon.titan-embed-text-v1", input=["good morning from litellm, attempting to embed data"]
+            model="amazon.titan-embed-text-v1", input=["good morning from litellm, attempting to embed data",
+                                                       "lets test a second string for good measure"]
        )
         print(f"response:", response)
     except Exception as e:
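The second changed file updates the Bedrock embedding test to send two strings, which the old code rejected with a 400 error. For reference, exercising the same path through litellm's public API would look roughly like this (assumes AWS credentials are configured for boto3, and dict-style access on the response, as the diff's own model_response["data"] assignment suggests):

from litellm import embedding

response = embedding(
    model="amazon.titan-embed-text-v1",
    input=["good morning from litellm, attempting to embed data",
           "lets test a second string for good measure"],
)
print(len(response["data"]))  # expect 2: one embedding per input string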